or ).
if isinstance(descendant, bs4.Tag):
if descendant.name in paragraph_tags:
if descendant.find_all(paragraph_tags):
# If there are nested paragraph tags, don't treat it as a single
# contiguous tag.
continue
skip_children = list(descendant.descendants)
text = " ".join(descendant.get_text(" ", strip=True).split())
if text:
yield text
continue
if (isinstance(descendant, bs4.Comment) or
not isinstance(descendant, bs4.NavigableString)):
continue
text = " ".join(descendant.strip().split())
if text:
yield text
================================================
FILE: tensor2tensor/data_generators/wikisum/parallel_launch.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
r"""Launch a script in parallel on GCP.
For each instance (`--num_instances`), the script will copy the code in
`--code_dir` to the instance, run `--setup_command` and then run
`--command_prefix` joined with the task's id or a line in
`--per_instance_suffix_file`.
Note that the machines will attempt to shut themselves down on completion or failure.
If they do not, you can delete them manually or use delete_instances.sh to
delete many at once.
Example usage:
```
BUCKET=gs://my-bucket
python parallel_launch.py \
--num_instances=1000 \
--cpu=4 --mem=4 \
--name=wikisum-refs-web \
--code_dir=./ \
--log_dir=$BUCKET/refs_logs \
--setup_command="pip3 install aiohttp cchardet aiodns bs4 -q --user" \
--command_prefix="python3 wikisum/get_references_web.py --out_dir=$BUCKET/wiki_references --shard_id"
```
"""
# pylint: enable=line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import multiprocessing as mp
import os
import socket
import subprocess as sp
import time
from tensor2tensor.utils import cloud_mlengine as cloud
import tensorflow.compat.v1 as tf
flags = tf.flags
FLAGS = flags.FLAGS

# Required: main() asserts that num_instances and name are set.
flags.DEFINE_integer("num_instances", None, "Number of instances to launch.")
# Instances are named "<name>-<instance id>".
flags.DEFINE_string("name", None, "Instance name prefix.")
# Required: main() joins `name` onto this path and asserts the result starts
# with "gs://".
flags.DEFINE_string("log_dir", None, "GCS bucket to copy logs out to.")
flags.DEFINE_string("code_dir", None, "Directory to copy.")
flags.DEFINE_string("setup_command", None, "Setup command to run.")
flags.DEFINE_string("command_prefix", None, "Command to run, prefix.")
# One suffix per line; main() asserts the line count equals num_instances.
flags.DEFINE_string("per_instance_suffix_file", None,
                    "Command to run, suffix per instance. If None, suffix will "
                    "be instance id.")
# cpu and mem are substituted into the CREATE_INSTANCE gcloud command.
flags.DEFINE_integer("cpu", 1, "Number of CPUs per instance.")
flags.DEFINE_integer("mem", 4, "Memory in GB per instance.")
# Size of the multiprocessing pool used to launch instances concurrently.
flags.DEFINE_integer("num_threads", 48,
                     "Number of threads to use to spin up jobs.")
# When set, the self-delete step is omitted from the remote command.
flags.DEFINE_bool("debug_keep_up", False,
                  "If True, will keep the machine up. num_instances must be 1.")
# Parsed as ints; each id also indexes the per-instance suffix list.
flags.DEFINE_string("instance_ids", None,
                    "Comma-separated list of integer instance ids to launch. "
                    "Useful if some failed on a previous run and you only want "
                    "to rerun specific tasks.")
# Shell command templates, filled in with str.format / cloud.shell_run kwargs.

# Deletes a named instance; used to clean up instances whose launch failed.
# NOTE(review): unlike DELETE_SELF this has no --quiet or --zone; confirm gcloud
# does not block on a confirmation/zone prompt when run non-interactively.
DELETE = "gcloud compute instances delete {name}"
# Run on the instance itself at the end of its command: $(hostname) is expanded
# by the remote shell, so the machine deletes itself.
DELETE_SELF = ("gcloud compute instances delete $(hostname) --quiet "
               "--zone={zone}")
# Custom machine type with {cpu} CPUs and {mem} memory, from the tf-1-7 image
# family, with full cloud-platform API scope.
CREATE_INSTANCE = ("gcloud compute instances create {instance_name} "
                   "--custom-cpu {cpu} --custom-memory {mem} "
                   "--custom-extensions "
                   "--image-project=ml-images --image-family=tf-1-7 "
                   "--scopes=cloud-platform")
# Recursively copies a local directory into the instance's home directory.
COPY_CODE = "gcloud compute scp --recurse {local_dir} {instance_name}:~/"
# remote_run splits this on whitespace and appends the remote command as one
# final argument.
SSH = "gcloud compute ssh {instance_name} --command"
# Runs a command in a detached screen session. The command is embedded in
# double quotes, so commands that themselves contain double quotes will break.
SCREEN = "screen -dmS test bash -c \"{command}\""
# Prints the gcloud-configured default compute zone.
DEFAULT_ZONE = "gcloud config get-value compute/zone"
# Shell suffix: redirect stdout/stderr to a per-task log file, then upload it.
LOGS = "> ~/logs-{task_id}.txt 2>&1; gsutil cp ~/logs-{task_id}.txt {bucket}"
def remote_run(cmd, instance_name, detach=False, retries=1):
  """Runs a shell command on a GCE instance via `gcloud compute ssh`.

  Args:
    cmd: shell command string to execute remotely.
    instance_name: name of the target instance.
    detach: if True, wrap `cmd` in a detached `screen` session (SCREEN).
    retries: number of extra attempts after a failed one.

  Returns:
    The result of `subprocess.check_call` for the successful attempt.

  Raises:
    subprocess.CalledProcessError: if the final attempt fails.
  """
  remote_cmd = SCREEN.format(command=cmd) if detach else cmd
  ssh_args = SSH.format(instance_name=instance_name).split() + [remote_cmd]
  for attempt in range(retries + 1):
    if attempt > 0:
      tf.logging.info("Retry %d for %s", attempt, ssh_args)
    try:
      return sp.check_call(ssh_args)
    except sp.CalledProcessError:
      if attempt == retries:
        raise
def default_zone():
  """Returns gcloud's configured default compute zone, whitespace-stripped."""
  zone_output = cloud.shell_output(DEFAULT_ZONE)
  return zone_output.strip()
@contextlib.contextmanager
def safe_socket(timeout=2):
  """Yields a TCP socket configured with `timeout` and closes it on exit.

  Args:
    timeout: socket timeout in seconds.

  Yields:
    A `socket.socket` instance.
  """
  s = socket.socket()
  try:
    # Configure inside the try so the socket is closed even if this raises.
    s.settimeout(timeout)
    yield s
  finally:
    s.close()


def wait_for_ssh(ip, retries=12, wait_secs=10, port=22):
  """Wait for SSH to be available at given IP address.

  Args:
    ip: IP address (string) of the host to probe.
    retries: maximum number of connection attempts.
    wait_secs: seconds to sleep after each failed attempt.
    port: TCP port to probe (22 is SSH).

  Returns:
    True as soon as a TCP connection succeeds, False if every attempt fails.
  """
  for _ in range(retries):
    with safe_socket() as s:
      try:
        s.connect((ip, port))
        return True
      except socket.error:
        # socket.timeout and errors such as "connection refused" (typical
        # while a freshly created VM is still booting sshd) both derive from
        # socket.error; treat all of them as "not ready yet" and retry.
        pass
    time.sleep(wait_secs)
  return False
def create_instance(instance_name, cpu=1, mem=4):
  """Creates a GCE instance and returns a field parsed from gcloud's output.

  The first and last output lines are dropped (presumably the table header
  and a trailing empty line); the returned value is whitespace-separated
  column 8 of the next line. The caller uses it as the instance IP.
  NOTE(review): this depends on gcloud's exact output layout — confirm.
  """
  tf.logging.info("Creating instance %s", instance_name)
  create_output = cloud.shell_output(
      CREATE_INSTANCE, instance_name=instance_name, cpu=cpu, mem=mem)
  body_lines = create_output.split("\n")[1:-1]
  first_row = body_lines[0]
  columns = first_row.split()
  return columns[8]
def list_vm_names_and_ips():
  """Returns (name, ip) tuples for existing VMs from `cloud.LIST_VM` output.

  The first and last output lines are dropped (presumably the header and a
  trailing empty line). For each remaining row, the first column is the name
  and the second-to-last column is the IP.
  """
  listing = cloud.shell_output(cloud.LIST_VM)
  pairs = []
  for row in listing.split("\n")[1:-1]:
    cols = row.split()
    pairs.append((cols[0].strip(), cols[-2].strip()))
  return pairs
def shell_run_with_retry(cmd, retries=1, **kwargs):
  """Runs `cloud.shell_run(cmd, **kwargs)`, retrying when it fails.

  Args:
    cmd: command template passed to `cloud.shell_run`.
    retries: number of extra attempts after a failed one.
    **kwargs: format arguments forwarded to `cloud.shell_run`.

  Raises:
    subprocess.CalledProcessError: if the final attempt fails.
  """
  for attempt in range(retries + 1):
    if attempt > 0:
      tf.logging.info("Retry %d for %s", attempt, cmd)
    try:
      cloud.shell_run(cmd, **kwargs)
    except sp.CalledProcessError:
      if attempt == retries:
        raise
    else:
      return
def delete_instance(instance_name):
  """Deletes the GCE instance `instance_name` using the DELETE template."""
  command_template = DELETE
  cloud.shell_run(command_template, name=instance_name)
def launch_instance(instance_name,
                    command,
                    existing_ip=None,
                    cpu=1,
                    mem=4,
                    code_dir=None,
                    setup_command=None):
  """Brings up (or reuses) a GCE instance and starts `command` on it.

  Steps: create the instance unless `existing_ip` is truthy, wait for SSH,
  copy `code_dir` if given, run `setup_command` if given (blocking), then run
  `command` detached in a screen session.

  Raises:
    ValueError: if SSH never becomes reachable.
  """
  if existing_ip:
    ip = existing_ip
  else:
    ip = create_instance(instance_name, cpu=cpu, mem=mem)

  tf.logging.info("Waiting for SSH %s", instance_name)
  if not wait_for_ssh(ip):
    raise ValueError("Instance %s never ready for SSH" % instance_name)

  if code_dir:
    shell_run_with_retry(COPY_CODE, retries=2,
                         local_dir=code_dir, instance_name=instance_name)

  if setup_command:
    tf.logging.info("Running setup on %s", instance_name)
    remote_run(setup_command, instance_name)

  tf.logging.info("Running command on %s", instance_name)
  remote_run(command, instance_name, detach=True)
def main(_):
  """Launches one GCE instance per task id and starts the command on each.

  The remote command for instance `i` is
  `<command_prefix> <suffix> <log redirect>; <self-delete>`. The suffix is `i`,
  or line `i` of --per_instance_suffix_file. Launches run concurrently on a
  multiprocessing pool. Failed launches are logged together with a
  ready-to-use --instance_ids value, and a best-effort delete is attempted for
  each of them.
  """
  assert FLAGS.num_instances
  assert FLAGS.name
  zone = default_zone()
  assert zone
  code_dir = None
  if FLAGS.code_dir:
    code_dir = os.path.abspath(os.path.expanduser(FLAGS.code_dir))

  # Suffixes per instance
  if FLAGS.per_instance_suffix_file:
    with tf.gfile.Open(FLAGS.per_instance_suffix_file) as f:
      suffixes = [l.strip() for l in f.readlines()]
  else:
    suffixes = list(range(FLAGS.num_instances))
  assert len(suffixes) == FLAGS.num_instances

  # Existing VMs are reused instead of re-created. Keep the first IP listed for
  # each name, in a dict so each per-instance lookup is O(1).
  vm_ips = {}
  for vm_name, vm_ip in list_vm_names_and_ips():
    vm_ips.setdefault(vm_name, vm_ip)

  assert FLAGS.log_dir
  log_dir = os.path.join(FLAGS.log_dir, FLAGS.name)
  tf.gfile.MakeDirs(log_dir)
  assert log_dir.startswith("gs://")
  if not log_dir.endswith("/"):
    log_dir += "/"
  # Write a test file to make sure gcloud GCS APIs are enabled
  test_filename = os.path.join(log_dir, "check_write")
  with tf.gfile.Open(test_filename, "w") as f:
    f.write("testing GCS write")
  tf.gfile.Remove(test_filename)

  instance_ids = list(range(FLAGS.num_instances))
  if FLAGS.instance_ids:
    instance_ids = [int(i) for i in FLAGS.instance_ids.split(",")]
  if FLAGS.debug_keep_up:
    # Keeping machines up is only supported for a single instance.
    assert len(instance_ids) == 1
    delete = ""
  else:
    # Each instance deletes itself once its command has finished.
    delete = DELETE_SELF.format(zone=zone)

  # Create worker processes only after the validation above has passed.
  pool = mp.Pool(FLAGS.num_threads)
  tf.logging.info("Launching %d instances", len(instance_ids))
  async_results = []
  for i in instance_ids:
    instance_name = "%s-%d" % (FLAGS.name, i)
    # log_dir is non-empty here (it ends with "/"), so output is always
    # redirected to a log file and uploaded.
    log_suffix = LOGS.format(task_id=i, bucket=log_dir)
    command = "{prefix} {suffix} {logging}; {delete}".format(
        prefix=FLAGS.command_prefix,
        suffix=suffixes[i],
        delete=delete,
        logging=log_suffix)
    args = (instance_name, command, vm_ips.get(instance_name),
            FLAGS.cpu, FLAGS.mem, code_dir,
            FLAGS.setup_command)
    res = pool.apply_async(launch_instance, args)
    async_results.append((res, instance_name, i))

  failed = []
  for res, instance_name, i in async_results:
    try:
      res.get()
    except Exception as e:  # pylint: disable=broad-except
      failed.append((instance_name, i))
      tf.logging.error("Failed to launch task %s due to exception %s",
                       instance_name, str(e))

  if failed:
    ids_for_flag = ",".join(str(i) for _, i in failed)
    tf.logging.error("Failed to launch %d jobs. Tasks: %s. "
                     "Attempting delete in case they are still up. Rerun with "
                     "--instance_ids='%s' to attempt relaunch.",
                     len(failed), str(failed), ids_for_flag)
    delete_results = [
        (instance_name, pool.apply_async(delete_instance, (instance_name,)))
        for instance_name, _ in failed
    ]
    for instance_name, res in delete_results:
      try:
        res.get()
      except Exception as e:  # pylint: disable=broad-except
        # Best effort: the instance may never have been created. Log instead
        # of silently swallowing, and let KeyboardInterrupt propagate.
        tf.logging.warning("Could not delete instance %s: %s",
                           instance_name, str(e))

  # Release the worker processes.
  pool.close()
  pool.join()
  tf.logging.info("Launching complete.")
if __name__ == "__main__":
  # INFO verbosity so the launcher's tf.logging.info progress messages show.
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
================================================
FILE: tensor2tensor/data_generators/wikisum/produce_examples.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Produce examples given a vocab, wikis, references, and dataset URLs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from six.moves import range
from tensor2tensor.data_generators.wikisum import utils
from tensor2tensor.data_generators.wikisum import wikisum
import tensorflow.compat.v1 as tf
flags = tf.flags
FLAGS = flags.FLAGS

# This run handles slice `task_id` of `num_tasks` (via utils.shard) of both the
# output files and the input shard ids.
flags.DEFINE_integer("num_tasks", 1000, "Number of parallel tasks.")
flags.DEFINE_integer("task_id", 0, "Task id in a parallel run.")
flags.DEFINE_string("out_dir", None, "Directory to write to.")
flags.DEFINE_string("wikis_dir",
                    "gs://tensor2tensor-data/wikisum/wiki_content/",
                    "Directory with wiki_content.tfrecords.")
flags.DEFINE_string("refs_dir", None, "Directory with process_X dirs")
flags.DEFINE_string("urls_dir", "gs://tensor2tensor-data/wikisum/wiki_urls/",
                    "Directory with wiki_urls.json")
# main() falls back to out_dir when this is unset.
flags.DEFINE_string("vocab_dir", None, "Directory with vocab file")
flags.DEFINE_bool("for_commoncrawl", False,
                  "Whether to use WikisumCommoncrawl or WikisumWeb.")
def main(_):
  """Produces the examples assigned to task --task_id of --num_tasks.

  Both the output file list and the input shard ids are split with
  utils.shard, and this task handles its own slice of each. The vocabulary is
  read from --vocab_dir, which defaults to --out_dir.
  """
  problem_cls = (wikisum.WikisumCommoncrawl if FLAGS.for_commoncrawl
                 else wikisum.WikisumWeb)
  problem = problem_cls()
  all_out_paths = problem.out_filepaths(FLAGS.out_dir)
  task_out_paths = utils.shard(all_out_paths, FLAGS.num_tasks)[FLAGS.task_id]
  if not FLAGS.vocab_dir:
    FLAGS.vocab_dir = FLAGS.out_dir
  every_shard_id = list(range(utils.NUM_SHARDS))
  task_shard_ids = utils.shard(every_shard_id, FLAGS.num_tasks)[FLAGS.task_id]
  with utils.timing("produce_examples"):
    wikisum.produce_examples(
        shard_ids=task_shard_ids,
        wikis_dir=FLAGS.wikis_dir,
        refs_dir=FLAGS.refs_dir,
        urls_dir=FLAGS.urls_dir,
        vocab_path=os.path.join(FLAGS.vocab_dir, problem.vocab_filename),
        out_filepaths=task_out_paths)
if __name__ == "__main__":
  # INFO verbosity so utils.timing's start/finish/duration messages appear.
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
================================================
FILE: tensor2tensor/data_generators/wikisum/test_data/para_bad1.txt
================================================
kolkata ward no 97 37
you are here : india » west bengal » kolkata » kolkata
this paragraph too short
a | b | c | d | e | f | g | h | i | j | k | l | m | n | o | p | q | r | s | t | u | v | w | x | y | z
123 123 123 123 985 9880 1230 0980 . 12398 .
- 5 . 7 % - 5 . 2 % - 15 . 1 % 4 . 7 % - 13 . 3 %
http : / / www . bbc . co . uk / sport / football / 24351521
no . - 26 beadon street .
{ { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } } { { # playpopup } } { { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } } { { genre } }
denham , samuel coulter , sally 133 oct 28 1819
browse by
================================================
FILE: tensor2tensor/data_generators/wikisum/test_data/para_good1.txt
================================================
this is a very good paragraph . it even has two sentences .
the castle that was soon to figure so largely in lee’s life lay fourteen miles
to the southwest of where he sat perched atop his tank . topped with storybook
crenelations and accompanied by a rich history , schloss itter , as it’s called
in german , was first mentioned in land records as early as 1240 . since then ,
itter has passed through a number of hands . after germany’s march 1938
annexation of austria , the castle’s robust construction and relatively remote
location attracted the attention of the notoriously secretive nazis . within
months of absorbing austria into the greater reich , the german government
requisitioned castle itter for unspecified “official use”—which included housing
for several months in 1942 an organization called the “german association for
combating the dangers of tobacco . ” on february 7 , 1943 , it fell into new
hands yet again , for on that day , the structure and all its outbuildings were
requisitioned by the wehrmacht on behalf of the ss .
the url for the site is http : / / www . bbc . co . uk / sport / football / 24351521 .
================================================
FILE: tensor2tensor/data_generators/wikisum/utils.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wikisum data generation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import datetime
import gzip
import os
import re
import urllib
import tensorflow.compat.v1 as tf
# pylint: disable=g-import-not-at-top
# To maintain compatibility with Python 2 and 3
try:
import cStringIO as StringIO
except ImportError:
import io as StringIO
# pylint: enable=g-import-not-at-top
# Each entry is a URL to the wet.paths.gz file for that CommonCrawl dump.
# Keys are date tags ('0917' -> the CC-MAIN-2017-39 crawl).
WET_PATHS_BY_DATE = {
    '0917': ('https://commoncrawl.s3.amazonaws.com/crawl-data/CC-MAIN-2017-39/'
             'wet.paths.gz'),
}
# Prepended to each relative path listed in a wet.paths.gz file to build its
# download URL (see wet_download_urls).
S3_HTTP_PREFIX = 'https://commoncrawl.s3.amazonaws.com/'
# Number of data shards used across the Wikisum pipeline.
NUM_SHARDS = 1000
# Suffix for metadata files. NOTE: the name is misspelled ("METADTA"); it is
# kept as-is because other modules may reference it.
METADTA_SUFFIX = '.metadata.json'
def readahead(path):
  """Returns `path` unchanged.

  Currently the identity function; gzip_memfile passes paths through it
  before opening them.
  """
  unchanged_path = path
  return unchanged_path
class WETHeader(collections.namedtuple('WETHeader', ['url', 'length'])):
  """Header of a single WET record: target URL and content length."""
  URI_HEADER = 'WARC-Target-URI: '
  LENGTH_HEADER = 'Content-Length: '

  @classmethod
  def read(cls, f):
    """Read header from file. Headers end with length and then 1 blank line.

    Args:
      f: file object positioned at the start of a record header. Its lines
        are compared against str prefixes, so it must yield str lines.

    Returns:
      A WETHeader, whose url is None if no WARC-Target-URI line was seen.
      Returns None at end of file, including when the file ends before a
      Content-Length line is found.
    """
    url = None
    line = f.readline()
    if not line:
      # EOF
      return None
    while not line.startswith(cls.LENGTH_HEADER):
      if line.startswith(cls.URI_HEADER):
        url = line[len(cls.URI_HEADER):].strip()
      line = f.readline()
      if not line:
        # Truncated header. readline() keeps returning '' at EOF, so without
        # this check the loop would never terminate.
        return None
    # Consume empty separator
    f.readline()
    length = int(line.split(':')[1])
    return cls(url, length)
class WETRecord(collections.namedtuple('WETRecord', ['url', 'content'])):
  """A single WET record: target URL and raw content."""

  @classmethod
  def read(cls, f):
    """Reads the next record from `f`, or returns None at end of file.

    A record is a header (parsed by WETHeader.read), then `header.length`
    units of content (bytes or characters, depending on f's mode), then two
    separator lines that are read and discarded.
    """
    header = WETHeader.read(f)
    if header is not None:
      body = f.read(header.length)
      for _ in range(2):
        f.readline()  # trailing blank separator lines
      return cls(header.url, body)
    # EOF
    return None
def wet_records_from_file_obj(f, take_ownership=False):
  """Iterate through records in WET file object.

  Records without a URL are skipped.

  Args:
    f: file object positioned at the start of a WET record.
    take_ownership: if True, `f` is closed when iteration ends. This includes
      the cases where the consumer stops early or an error is raised.

  Yields:
    WETRecord instances.
  """
  try:
    while True:
      record = WETRecord.read(f)
      if record is None:
        break
      if not record.url:
        continue
      yield record
  finally:
    # A `finally` also runs on generator.close() (e.g. when a partially
    # consumed generator is discarded), which a plain post-loop close misses.
    if take_ownership:
      f.close()
def wet_records(wet_filepath):
  """Yields WETRecords from the file at `wet_filepath`.

  Paths ending in '.gz' are opened with gzip.open; anything else with
  tf.gfile.GFile. The file is closed by the `with` block when iteration
  stops.
  """
  is_gzipped = wet_filepath.endswith('.gz')
  opener = gzip.open if is_gzipped else tf.gfile.GFile
  with opener(wet_filepath) as wet_file:
    records = wet_records_from_file_obj(wet_file)
    for rec in records:
      yield rec
def download(url, download_dir):
  """Downloads `url` into `download_dir` unless it is already there.

  The data is first written to '<name>.incomplete' and renamed only after the
  download returns, so an interrupted download is never taken as complete.

  Args:
    url: URL to fetch; the local file name is its basename.
    download_dir: local directory to download into.

  Returns:
    Path of the downloaded (or already existing) file.
  """
  # `urllib.urlretrieve` only exists on Python 2; on Python 3 it lives in
  # urllib.request (which a bare `import urllib` does not load).
  # pylint: disable=g-import-not-at-top
  try:
    from urllib.request import urlretrieve  # Python 3
  except ImportError:
    from urllib import urlretrieve  # Python 2
  # pylint: enable=g-import-not-at-top
  outname = os.path.join(download_dir, os.path.basename(url))
  if tf.gfile.Exists(outname):
    print('Found %s, skipping download' % outname)
    return outname
  inprogress = outname + '.incomplete'
  print('Downloading %s' % url)
  inprogress, _ = urlretrieve(url, inprogress)
  tf.gfile.Rename(inprogress, outname)
  return outname
def wet_download_urls(wet_paths_url, tmp_dir, rm_after=True):
  """Yields download URLs for every WET file listed in a wet.paths.gz file.

  Args:
    wet_paths_url: URL of a CommonCrawl wet.paths.gz listing.
    tmp_dir: directory to download the listing into.
    rm_after: if True, delete the downloaded listing when iteration ends,
      including when the consumer stops early.

  Yields:
    S3_HTTP_PREFIX joined with each non-empty relative path in the listing.
  """
  paths_gz = download(wet_paths_url, tmp_dir)
  try:
    with gzip.open(paths_gz) as f:
      for path in f:
        # gzip yields bytes on Python 3; decode before joining to the str
        # prefix. (On Python 2 the line is already a str.)
        if not isinstance(path, str):
          path = path.decode('utf-8')
        # strip() rather than path[:-1]: the last line may lack a trailing
        # newline, and blank lines must not produce bare-prefix URLs.
        path = path.strip()
        if path:
          yield S3_HTTP_PREFIX + path
  finally:
    # Mirrors wet_records_from_url: clean up even if iteration stops early.
    if rm_after:
      tf.gfile.Remove(paths_gz)
def wet_records_from_url(download_url, tmp_dir, rm_after=True):
  """Downloads a WET file into `tmp_dir` and yields its records.

  Args:
    download_url: URL of the WET file.
    tmp_dir: directory to download into.
    rm_after: if True, delete the local copy when iteration ends, whether it
      finishes, is stopped early, or raises.

  Yields:
    WETRecord instances from the downloaded file.
  """
  local_path = download(download_url, tmp_dir)
  records = wet_records(local_path)
  try:
    for rec in records:
      yield rec
  finally:
    if rm_after:
      tf.gfile.Remove(local_path)
class DummyPool(object):
  """Synchronous stand-in for `multiprocessing.Pool`.

  Work runs immediately in the calling process. Exceptions raised by the
  submitted function propagate from `apply_async`/`map` themselves, not from
  `DummyResult.get()`.
  """

  def __init__(self, processes=None):
    # `processes` is accepted for signature compatibility and ignored.
    pass

  def apply_async(self, fn, args=None, kwds=None):
    """Calls `fn(*args, **kwds)` now and wraps the return value."""
    args = args or tuple()
    kwds = kwds or {}
    return DummyResult(fn(*args, **kwds))

  def map(self, fn, arg_list, chunksize=None):
    """Returns `[fn(a) for a in arg_list]`; `chunksize` is ignored."""
    del chunksize  # Only accepted for Pool.map compatibility.
    return [fn(a) for a in arg_list]

  def close(self):
    """No-op, so code that closes a real Pool also works with DummyPool."""

  def join(self):
    """No-op, so code that joins a real Pool also works with DummyPool."""


class DummyResult(object):
  """Already-computed result exposing AsyncResult's `get()` interface."""

  def __init__(self, result):
    self.result = result

  def get(self, timeout=None):
    """Returns the stored result; `timeout` is ignored."""
    del timeout  # Only accepted for AsyncResult.get compatibility.
    return self.result
def shard(items, num_shards):
  """Splits the list `items` into `num_shards` contiguous groups.

  Every group first receives len(items) // num_shards consecutive items.
  The len(items) % num_shards leftover items at the end are then appended
  one each to the first groups.

  Args:
    items: list to split.
    num_shards: number of groups to produce.

  Returns:
    List of `num_shards` lists.
  """
  per_shard, leftover = divmod(len(items), num_shards)
  shards = [items[k * per_shard:(k + 1) * per_shard]
            for k in range(num_shards)]
  tail_start = len(items) - leftover
  for k, extra in enumerate(items[tail_start:]):
    shards[k].append(extra)
  assert sum(len(s) for s in shards) == len(items)
  return shards
def gzip_memfile(fname):
  """Reads gzip file `fname` fully into memory and returns a GzipFile over it.

  The file is read in binary mode and buffered in io.BytesIO. Gzip data is
  binary, so a text-mode read, or the io.StringIO that the module-level
  StringIO alias falls back to on Python 3, would fail there.

  Args:
    fname: path readable by tf.gfile.

  Returns:
    A gzip.GzipFile that decompresses the in-memory copy (on Python 3 its
    reads return bytes).
  """
  import io  # pylint: disable=g-import-not-at-top
  with tf.gfile.Open(readahead(fname), 'rb') as f:
    memfile = io.BytesIO(f.read())
  return gzip.GzipFile(fileobj=memfile)
# Matches if the text contains at least one ASCII letter.
_SOME_ALPHA_RE = re.compile(r'[A-Za-z]+')
# Matches a token made up only of ASCII letters.
_ONLY_ALPHA_RE = re.compile(r'^[A-Za-z]*$')


def filter_paragraph(p):
  """Simple filter to remove obviously bad paragraphs (bad text extraction).

  This must stay cheap because it runs on every paragraph in the corpus. It
  does one regex search plus a single pass over the tokens, so expected time
  is linear in len(p).

  A paragraph is kept only if all of the following hold:
  - it has at least 6 whitespace-separated tokens;
  - it contains at least one ASCII letter;
  - it contains a "sentence": a '.' token more than 3 positions after the
    previous '.' (or the start), with at least 3 purely alphabetic tokens
    seen since that previous '.'.

  Args:
    p: string, paragraph

  Returns:
    True if we should remove the paragraph.
  """
  tokens = p.split()
  # Too few words to be a real paragraph.
  if len(tokens) < 6:
    return True
  # No letters at all (e.g. rows of numbers).
  if _SOME_ALPHA_RE.search(p) is None:
    return True
  # Sentence detection is checked last because it is the most involved.
  sentence_start = 0
  alpha_count = 0
  for idx, tok in enumerate(tokens):
    if tok == '.':
      if idx - sentence_start > 3 and alpha_count >= 3:
        return False
      sentence_start = idx
      alpha_count = 0
    elif _ONLY_ALPHA_RE.match(tok):
      alpha_count += 1
  return True
@contextlib.contextmanager
def timing(name=''):
  """Logs when the wrapped block starts and finishes, and how long it took.

  Start and finish times are logged as HH:MM. The duration is logged in
  whole minutes (truncated). If the wrapped block raises, only the start
  message is logged.
  """
  def _log_stamp(template, when):
    tf.logging.info(template, name, when.strftime('%H:%M'))

  began = datetime.datetime.now()
  _log_stamp('Starting job [%s] at %s', began)
  yield
  ended = datetime.datetime.now()
  _log_stamp('Finished job [%s] at %s', ended)
  elapsed_minutes = (ended - began).total_seconds() / 60
  tf.logging.info('Total time [%s] (m): %d', name, int(elapsed_minutes))
================================================
FILE: tensor2tensor/data_generators/wikisum/utils_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.data_generators.wikisum.utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensor2tensor.data_generators.wikisum import utils
import tensorflow.compat.v1 as tf
# Directory containing this test file: take the absolute file path, then drop
# the file name.
pkg_dir = os.path.abspath(__file__)
pkg_dir, _ = os.path.split(pkg_dir)
# Fixture directory holding the para_bad*.txt / para_good*.txt files.
_TESTDATA = os.path.join(pkg_dir, "test_data")
def _get_testdata(filename):
  """Returns the entire text content of `filename`."""
  with tf.io.gfile.GFile(filename) as fh:
    contents = fh.read()
  return contents
class UtilsTest(tf.test.TestCase):

  def test_filter_paragraph(self):
    # Each line of a para_bad* file is an independent paragraph that must be
    # filtered out.
    bad_files = tf.io.gfile.glob(os.path.join(_TESTDATA, "para_bad*.txt"))
    self.assertTrue(bad_files, msg="No para_bad*.txt fixtures found")
    for bad in bad_files:
      for p in _get_testdata(bad).split("\n"):
        self.assertTrue(utils.filter_paragraph(p),
                        msg="Didn't filter %s" % p)
    # para_good* files contain hard-wrapped paragraphs (one paragraph may span
    # several lines), so each file is checked once, as a whole, rather than
    # line by line.
    good_files = tf.io.gfile.glob(os.path.join(_TESTDATA, "para_good*.txt"))
    self.assertTrue(good_files, msg="No para_good*.txt fixtures found")
    for good in good_files:
      p = _get_testdata(good)
      self.assertFalse(utils.filter_paragraph(p), msg="Filtered %s" % p)
if __name__ == "__main__":
  # Run the test cases defined in this module.
  tf.test.main()
================================================
FILE: tensor2tensor/data_generators/wikisum/validate_data.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Aggregate stats from produce_examples."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import numpy as np
import six
from six.moves import zip
from tensor2tensor.data_generators.wikisum import wikisum
import tensorflow.compat.v1 as tf
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("out_dir", None, "Directory with data and stats files.")
# In main() this also sets the minimum expected data file size (60e6 vs 120e6
# bytes). The reference-coverage check is skipped when it is True.
flags.DEFINE_bool("for_commoncrawl", False,
                  "Whether to use WikisumCommoncrawl or WikisumWeb.")
# main() only removes per-shard stats when no data files need regenerating.
flags.DEFINE_bool("rm_per_shard_stats", True,
                  "Whether to remove the per-shard stats files after writing "
                  "out the aggregated stats.")
def aggregate_stats(stats_files):
  """Aggregate stats in per-shard stats files.

  Each file holds a JSON object. List values are concatenated across files
  and all other values are summed. The totals are then summarized as
  coverage ratios and histograms.

  Args:
    stats_files: paths of the per-shard JSON stats files.

  Returns:
    dict of aggregated statistics. All values are built-in Python numbers or
    lists of them, so the result can be passed directly to json.dumps.
  """
  all_stats = {}
  for fname in stats_files:
    with tf.gfile.Open(fname) as f:
      stats = json.loads(f.read())
    for k, v in six.iteritems(stats):
      if isinstance(v, list):
        all_stats.setdefault(k, []).extend(v)
      else:
        all_stats[k] = all_stats.get(k, 0) + v

  stats = all_stats
  ref_coverage = float(stats["total_found_refs"]) / stats["total_original_refs"]
  len_bounds = [0, 2, 10, 100, 1000, 5000, 10000, 20000, 50000, 100000, 1000000]
  len_counts, len_bounds = np.histogram(stats["ref_lengths"], len_bounds)
  len_dist = len_counts.astype(np.float32) / len_counts.sum()
  wiki_coverage = (float(stats["num_wikis_written"]) /
                   stats["total_original_wikis"])
  wikis_skipped_no_ref = (float(stats["wikis_skipped_no_refs"]) /
                          stats["total_original_wikis"])
  wikis_skipped_no_lead = (float(stats["wikis_skipped_short_lead"]) /
                           stats["total_original_wikis"])
  # Per-wiki fraction of references found; wikis with none found are skipped.
  wiki_ref_coverage = [
      float(found) / orig for found, orig
      in zip(stats["wiki_found_refs"], stats["wiki_original_refs"]) if found
  ]
  coverage_bounds = np.arange(21).astype(np.float32) / 20
  coverage_counts, coverage_bounds = np.histogram(wiki_ref_coverage,
                                                  coverage_bounds)
  coverage_dist = coverage_counts.astype(np.float32) / coverage_counts.sum()

  # .tolist() converts to built-in Python ints. list(ndarray) would give numpy
  # scalars (e.g. np.int64), which json.dumps cannot serialize.
  agg_stats = dict(
      total_original_wikis=stats["total_original_wikis"],
      total_original_refs=stats["total_original_refs"],
      wiki_coverage=wiki_coverage,
      wikis_skipped_no_ref=wikis_skipped_no_ref,
      wikis_skipped_no_lead=wikis_skipped_no_lead,
      overall_ref_coverage=ref_coverage,
      per_wiki_ref_coverage_dist=(coverage_dist * 100).astype(int).tolist(),
      per_wiki_ref_coverage_bounds=(
          (coverage_bounds * 100).astype(int).tolist()),
      ref_len_dist=(len_dist * 100).astype(int).tolist(),
      ref_len_bounds=len_bounds.tolist(),
  )
  return agg_stats
def filename_to_task_id(fname):
  """Map filename to the task id that created it assuming 1k tasks.

  WikisumBase.out_filepaths returns its 800 train, 100 dev and 100 test paths
  *sorted*. The splits share a common prefix, so the sorted order is grouped
  alphabetically by split name: dev (tasks 0-99), test (100-199),
  train (200-999). With 1000 tasks, task t produces the t-th sorted file.

  Args:
    fname: data file path like '<dir>/<prefix>-<split>-<shard>-of-<total>'.

  Returns:
    int task id.
  """
  num_shards_by_split = {"train": 800, "dev": 100, "test": 100}
  shard_id_increment = {}
  offset = 0
  for split_name in sorted(num_shards_by_split):
    shard_id_increment[split_name] = offset
    offset += num_shards_by_split[split_name]
  # Parse from the end so a '-' inside the dataset prefix cannot shift fields.
  parts = os.path.basename(fname).split("-")
  split = parts[-4]
  shard_id = parts[-3]
  return int(shard_id) + shard_id_increment[split]
def get_length(fname):
  """Returns the size of `fname` in bytes, as reported by tf.gfile.Stat."""
  file_stat = tf.gfile.Stat(fname)
  return file_stat.length
def validate_data_files(problem, data_files, min_size):
  """Validate presence and minimum size of files.

  Args:
    problem: Wikisum problem whose `out_filepaths` lists the expected files.
    data_files: data file paths actually found. The expected paths are
      computed for the directory of the first one.
    min_size: minimum acceptable file size in bytes.

  Returns:
    The files smaller than `min_size` followed by the expected files that are
    missing, each group in sorted order.

  Raises:
    ValueError: if `data_files` is empty, because the data directory cannot
      then be determined.
  """
  if not data_files:
    raise ValueError(
        "No data files given; cannot determine the data directory.")
  # Check that all files are present
  data_dir = os.path.split(data_files[0])[0]
  out_filepaths = problem.out_filepaths(data_dir)
  missing_filepaths = sorted(set(out_filepaths) - set(data_files))
  if missing_filepaths:
    tf.logging.error("Missing %d data files", len(missing_filepaths))
  # Check that each file is at least `min_size` bytes
  too_small = sorted(f for f in data_files if get_length(f) < min_size)
  if too_small:
    tf.logging.error("%d files too small", len(too_small))
  return too_small + missing_filepaths
def main(_):
  """Validates produced data files and writes aggregated stats.json.

  Logs the task ids whose data files are missing or too small, checks overall
  reference coverage for WikisumWeb, and writes the aggregated per-shard stats
  to `<out_dir>/stats.json`.
  """
  if FLAGS.for_commoncrawl:
    problem = wikisum.WikisumCommoncrawl()
  else:
    problem = wikisum.WikisumWeb()
  prefix = problem.dataset_filename()
  data_files = tf.gfile.Glob(os.path.join(FLAGS.out_dir, "%s*" % prefix))
  bad_files = validate_data_files(
      problem, data_files,
      min_size=(60 if FLAGS.for_commoncrawl else 120) * 1e6)
  # Only report (and suggest a rerun) when something actually needs redoing;
  # previously an error was logged even when every file was fine.
  if bad_files:
    task_ids = sorted(filename_to_task_id(fname) for fname in bad_files)
    ids_for_flag = ",".join([str(i) for i in task_ids])
    tf.logging.error("You should (re)generate %d of the data files. "
                     "Rerun produce_examples with --instance_ids='%s'.",
                     len(bad_files), ids_for_flag)

  # Compute and write out aggregated stats
  stats_files = tf.gfile.Glob(os.path.join(FLAGS.out_dir, "stats*"))
  agg_stats = aggregate_stats(stats_files)
  if not FLAGS.for_commoncrawl:
    coverage = agg_stats["overall_ref_coverage"] * 100
    if not coverage > 80:
      tf.logging.error("Overall reference coverage is expected to be > 80%. "
                       "It is %0.1f. You may want to rerun get_references_web.",
                       coverage)
  with tf.gfile.Open(
      os.path.join(FLAGS.out_dir, "stats.json"), "w") as f:
    f.write(json.dumps(agg_stats))
  # Keep per-shard stats while files still need regenerating, so a later run
  # can re-aggregate them.
  if FLAGS.rm_per_shard_stats and not bad_files:
    for fname in stats_files:
      tf.gfile.Remove(fname)
if __name__ == "__main__":
  # No tf.logging.set_verbosity here: this script only logs at ERROR level.
  tf.app.run()
================================================
FILE: tensor2tensor/data_generators/wikisum/wikisum.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wikipedia Summarization Problems."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import math
import os
import re
import string
import tempfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import tokenizer
from tensor2tensor.data_generators.wikisum import utils as cc_utils
from tensor2tensor.layers import modalities
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
# Prefix of the per-process reference directories under refs_dir
# ("process_X" dirs).
PROCESS_FOLDER_PREFIX = "process"
# Reference shard file names; REF_SHARD_FILE is formatted with a shard id.
REF_SHARD_FILE_PREFIX = "references.tfrecords.gz"
REF_SHARD_FILE = REF_SHARD_FILE_PREFIX + "-%05d-of-01000"

# Support files
BASE_SUPPORT_DIR = "gs://tensor2tensor-data/wikisum"
WIKI_CONTENT_DIR = os.path.join(BASE_SUPPORT_DIR, "wiki_content")
WIKI_URLS_DIR = os.path.join(BASE_SUPPORT_DIR, "wiki_urls")
WET_METADATA_DIR = os.path.join(BASE_SUPPORT_DIR, "commoncrawl_metadata")
# Per-shard file names, formatted with a shard id (1000 shards).
WIKI_CONTENT_FILE = "wiki_content.tfrecords-%05d-of-01000"
WIKI_URLS_FILE = "wiki_urls.json-%05d-of-01000"

# NOTE(review): an empty string cannot actually delimit a title from what
# follows it. This looks like a lost marker token (possibly stripped as
# markup) — confirm the intended value before relying on it.
EOT = ""  # end-of-title string

# Minimum thresholds; presumably used when deciding to skip wikis with too few
# references / too short a lead section — verify against their uses.
_MIN_REFS = 1
_MIN_LEADSECTION_TOKENS = 1
class WikisumBase(problem.Problem):
"""Base class for Wikisum problems."""
def example_reading_spec(self):
data_fields = {
"inputs": tf.VarLenFeature(tf.int64),
"targets": tf.VarLenFeature(tf.int64),
"section_boundaries": tf.VarLenFeature(tf.int64),
}
data_items_to_decoders = None
return (data_fields, data_items_to_decoders)
@property
def target_vocab_size(self):
return 2**15
@property
def vocab_filename(self):
return "vocab.%s.%d" % (self.dataset_filename(), self.target_vocab_size)
def feature_encoders(self, data_dir):
vocab_filename = os.path.join(data_dir, self.vocab_filename)
encoder = text_encoder.SubwordTextEncoder(vocab_filename)
# Shared encoder for inputs and targets
return {"inputs": encoder, "targets": encoder}
def hparams(self, defaults, unused_model_hparams):
p = defaults
p.stop_at_eos = True
p.vocab_size = {
"inputs": self._encoders["inputs"].vocab_size,
"targets": self._encoders["targets"].vocab_size,
}
p.modality = {
"inputs": modalities.ModalityType.SYMBOL,
"targets": modalities.ModalityType.SYMBOL,
}
def eval_metrics(self):
return super(WikisumBase, self).eval_metrics() + [
metrics.Metrics.ROUGE_2_F, metrics.Metrics.ROUGE_L_F
]
def generate_lines_for_vocab(self, wikis_dir, refs_dir, max_chars=10**7):
total_chars = 0
ref_files_by_shard = _references_files_by_shard(refs_dir)
for shard_id in range(cc_utils.NUM_SHARDS):
# Wikipedia articles
for wiki in _wiki_articles(shard_id, wikis_dir):
yield _normalize_text(wiki.title) + EOT
for section in wiki.sections:
yield _format_title(_normalize_text(section.title))
yield _normalize_text(section.text)
total_chars += len(section.title)
total_chars += len(section.text)
# References
for i, content in enumerate(
six.itervalues(_references_content(ref_files_by_shard[shard_id]))):
for line in content.split("\n"):
if line:
yield _normalize_text(line)
total_chars += len(line)
# Make sure we use at least 1k references
if i >= 1000 and total_chars >= max_chars:
break
if total_chars >= max_chars:
tf.logging.info("Seen enough chars: %d; finished.", max_chars)
break
tf.logging.info("Built vocabulary using %d chars", total_chars)
def generate_vocab(self, data_dir, wikis_dir, refs_dir):
# Produce a SubwordTextEncoder from a subset of the data
return generator_utils.get_or_generate_vocab_inner(
data_dir, self.vocab_filename, self.target_vocab_size,
self.generate_lines_for_vocab(wikis_dir, refs_dir))
def generate_data(self, data_dir, tmp_dir, task_id=-1):
  """Generates nothing; only logs a pointer to wikisum/README.md."""
  del data_dir, tmp_dir, task_id  # Unused.
  tf.logging.warn("See wikisum/README.md for instructions to generate data.")
def out_filepaths(self, data_dir):
  """Returns the sorted shuffled output paths for train, dev and test.

  Uses 800 train, 100 dev and 100 test shards. Their total is asserted to
  equal cc_utils.NUM_SHARDS.
  """
  num_train, num_dev, num_test = 800, 100, 100
  paths = []
  paths.extend(self.training_filepaths(data_dir, num_train, shuffled=True))
  paths.extend(self.dev_filepaths(data_dir, num_dev, shuffled=True))
  paths.extend(self.test_filepaths(data_dir, num_test, shuffled=True))
  paths = sorted(paths)
  assert len(paths) == cc_utils.NUM_SHARDS
  return paths
@registry.register_problem
class WikisumCommoncrawl(WikisumBase):
  """Wikipedia references->article summarization task based on CommonCrawl.

  All behavior is inherited from WikisumBase.
  """
@registry.register_problem
class WikisumWeb(WikisumBase):
  """Wikipedia references->article summarization task based on web data.

  All behavior is inherited from WikisumBase.
  """
@registry.register_problem
class WikisumCommoncrawlLeadSection(WikisumCommoncrawl):
  """Wikipedia references->lead section summarization task.

  Reads the WikisumCommoncrawl data files and truncates every target to the
  article's lead section at preprocessing time.
  """

  def preprocess_example(self, example, mode, hparams):
    """Cuts targets to the lead section, then applies parent preprocessing."""
    lead_only = _truncate_to_lead_section(example)
    example["targets"] = lead_only
    parent = super(WikisumCommoncrawlLeadSection, self)
    return parent.preprocess_example(example, mode, hparams)

  def dataset_filename(self):
    """Points at the parent problem's dataset so its files are reused."""
    return WikisumCommoncrawl.name

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    """Generates nothing; only logs which problem's data is reused."""
    del data_dir, tmp_dir, task_id  # Unused.
    tf.logging.warn("Problem %s reuses data from problem %s", self.name,
                    WikisumCommoncrawl.name)
@registry.register_problem
class WikisumWebLeadSection(WikisumWeb):
  """Wikipedia references->lead section summarization task.

  Reads the WikisumWeb data files and truncates every target to the
  article's lead section at preprocessing time.
  """

  def preprocess_example(self, example, mode, hparams):
    """Cuts targets to the lead section, then applies parent preprocessing."""
    lead_only = _truncate_to_lead_section(example)
    example["targets"] = lead_only
    parent = super(WikisumWebLeadSection, self)
    return parent.preprocess_example(example, mode, hparams)

  def dataset_filename(self):
    """Points at the parent problem's dataset so its files are reused."""
    return WikisumWeb.name

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    """Generates nothing; only logs which problem's data is reused."""
    del data_dir, tmp_dir, task_id  # Unused.
    tf.logging.warn("Problem %s reuses data from problem %s", self.name,
                    WikisumWeb.name)
def make_ref_shard_files(out_dir):
  """Creates out_dir and opens one GZIP TFRecordWriter per reference shard.

  Returns:
    List of cc_utils.NUM_SHARDS writers; entry i writes REF_SHARD_FILE % i.
    The caller is responsible for closing them.
  """
  tf.gfile.MakeDirs(out_dir)
  gzip_opts = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)
  writers = []
  for shard_idx in range(cc_utils.NUM_SHARDS):
    shard_path = os.path.join(out_dir, REF_SHARD_FILE % shard_idx)
    writers.append(tf.python_io.TFRecordWriter(shard_path, gzip_opts))
  return writers
def _truncate_to_lead_section(example):
  """Returns example["targets"] cut at the first section boundary, plus EOS."""
  first_boundary = example["section_boundaries"][0]
  # Slicing drops the article's original EOS, so a fresh one is appended.
  return tf.concat(
      (example["targets"][:first_boundary], [text_encoder.EOS_ID]), 0)
def _make_example_from_record(record):
  """Wraps a record's `url` and `content` as bytes features of an Example."""

  def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  feature_map = {
      "url": _bytes_feature(record.url),
      "content": _bytes_feature(record.content),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))
def _shard_id_for_file(sharded_filename):
suffix = "00000-of-00000"
parts = sharded_filename[-len(suffix):].split("-")
assert len(parts) == 3
return int(parts[0])
def _references_files_by_shard(refs_dir):
  """Groups reference shard files from every process folder by shard id.

  Returns:
    collections.defaultdict(list) mapping shard id -> list of file paths.
  """
  files_by_shard = collections.defaultdict(list)
  for process_dir in _process_folders(refs_dir):
    pattern = os.path.join(process_dir, REF_SHARD_FILE_PREFIX) + "*"
    for ref_path in tf.gfile.Glob(pattern):
      files_by_shard[_shard_id_for_file(ref_path)].append(ref_path)
  return files_by_shard
def _references_content(ref_files):
  """Reads gzipped reference TFRecords into a dict.

  Returns:
    Dict mapping each record's url to its content decoded to unicode. When
    a url repeats, the last record wins.
  """
  spec = {
      "url": tf.FixedLenFeature([], tf.string),
      "content": tf.FixedLenFeature([], tf.string),
  }
  records = generator_utils.tfrecord_iterator(
      ref_files, gzipped=True, example_spec=spec)
  return {
      rec["url"]: text_encoder.to_unicode(rec["content"]) for rec in records
  }
def _wiki_urls_for_shard(shard_id, urls_dir=None):
  """Loads the JSON wiki-url mapping stored for one shard.

  Args:
    shard_id: shard index substituted into WIKI_URLS_FILE.
    urls_dir: directory to read from; WIKI_URLS_DIR when empty.

  Returns:
    The decoded JSON object (produce_examples indexes it as
    `[wiki_url]["refs"]`).
  """
  if not urls_dir:
    urls_dir = WIKI_URLS_DIR
  urls_path = os.path.join(urls_dir, WIKI_URLS_FILE % shard_id)
  with tf.gfile.GFile(urls_path) as json_file:
    raw_json = json_file.read()
  return json.loads(raw_json)
class WikipediaSection(
    collections.namedtuple("WikipediaSection", ["title", "text"])):
  """One section of a Wikipedia article: its title and body text."""
  # Empty slots keep instances as lightweight as the base namedtuple (no
  # per-instance __dict__), as recommended for namedtuple subclasses.
  __slots__ = ()
class WikipediaArticle(
    collections.namedtuple("WikipediaArticle", ["url", "title", "sections"])):
  """A Wikipedia article: url, title and a list of WikipediaSection."""
  # Empty slots keep instances as lightweight as the base namedtuple (no
  # per-instance __dict__), as recommended for namedtuple subclasses.
  __slots__ = ()
def _wiki_articles(shard_id, wikis_dir=None):
  """Generates WikipediaArticles from GCS that are part of shard shard_id.

  Reads `WIKI_CONTENT_FILE % shard_id` under wikis_dir (WIKI_CONTENT_DIR when
  wikis_dir is empty). The read runs inside a fresh tf.Graph with its own
  tf.Session, and every string field is decoded to unicode.

  Args:
    shard_id: index of the shard to read.
    wikis_dir: optional directory overriding WIKI_CONTENT_DIR.

  Yields:
    WikipediaArticle tuples whose `sections` is a list of WikipediaSection.
  """
  wikis_dir = wikis_dir or WIKI_CONTENT_DIR
  shard_path = os.path.join(wikis_dir, WIKI_CONTENT_FILE % shard_id)
  string_fields = ("url", "title", "section_titles", "section_texts")

  def _parse_example(ex_ser):
    """Parse serialized Example containing Wikipedia article content."""
    parsed = tf.parse_single_example(
        ex_ser, {name: tf.VarLenFeature(tf.string) for name in string_fields})
    # Each field is parsed as a SparseTensor; keep only its values.
    article = {name: parsed[name].values for name in parsed}
    # url and title hold one value each; unwrap them.
    article["url"] = article["url"][0]
    article["title"] = article["title"][0]
    return article

  with tf.Graph().as_default():
    dataset = tf.data.TFRecordDataset(
        cc_utils.readahead(shard_path), buffer_size=16 * 1000 * 1000)
    dataset = dataset.map(_parse_example, num_parallel_calls=32)
    dataset = dataset.prefetch(100)
    next_article = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      while True:
        try:
          ex = sess.run(next_article)
        except tf.errors.OutOfRangeError:
          break
        sections = [
            WikipediaSection(title=text_encoder.to_unicode(sec_title),
                             text=text_encoder.to_unicode(sec_text))
            for sec_title, sec_text in zip(ex["section_titles"],
                                           ex["section_texts"])
        ]
        yield WikipediaArticle(
            url=text_encoder.to_unicode(ex["url"]),
            title=text_encoder.to_unicode(ex["title"]),
            sections=sections)
def _token_counts(text, token_set=None):
  """Counts tokenizer tokens in text, optionally restricted to token_set.

  Args:
    text: text to tokenize with `tokenizer.encode`.
    token_set: if not None, only tokens in this collection are counted. An
      empty collection therefore counts nothing; only None disables the
      filter.

  Returns:
    collections.defaultdict(int) mapping token -> occurrence count.
  """
  counts = collections.defaultdict(int)
  for token in tokenizer.encode(text_encoder.native_to_unicode(text)):
    # Compare against None explicitly: previously an empty set was truthy-
    # tested and silently behaved like "no filter", counting every token.
    if token_set is not None and token not in token_set:
      continue
    counts[token] += 1
  return counts
def _normalize_text(text):
text = text.lower()
# Space around punctuation
text = re.sub("[%s]" % re.escape(string.punctuation), r" \g<0> ", text)
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def _tokens_to_score(tokens):
return {t for t in tokens if re.search("[a-z0-9]", t)}
def rank_reference_paragraphs(wiki_title, references_content, normalize=True):
  """Rank and return reference paragraphs by tf-idf score on title tokens.

  Each reference document is split on newlines into paragraphs. Paragraphs
  for which `cc_utils.filter_paragraph` returns true are dropped. Every kept
  paragraph p is scored as

    sum over title tokens t of  tf(t, p) * log(N / max(df(t), 1))

  where tf is the count of t in the normalized paragraph, N is the number of
  kept paragraphs, and df(t) is the number of kept paragraphs containing t.

  Args:
    wiki_title: article title; it is normalized before tokenization.
    references_content: iterable of reference document strings.
    normalize: if True, return normalized paragraph text; otherwise return
      the original paragraph text. Scoring always uses the normalized text.

  Returns:
    List of paragraph strings, sorted by descending score. Ties keep their
    original order.
  """
  normalized_title = _normalize_text(wiki_title)
  title_tokens = _tokens_to_score(
      set(tokenizer.encode(text_encoder.native_to_unicode(normalized_title))))

  ref_paragraph_info = []
  doc_counts = collections.defaultdict(int)
  for ref in references_content:
    for paragraph in ref.split("\n"):
      normalized_paragraph = _normalize_text(paragraph)
      if cc_utils.filter_paragraph(normalized_paragraph):
        # Skip paragraph
        continue
      counts = _token_counts(normalized_paragraph, title_tokens)
      for token in title_tokens:
        if counts[token]:
          doc_counts[token] += 1
      content = normalized_paragraph if normalize else paragraph
      ref_paragraph_info.append({"content": content, "counts": counts})

  # The IDF factor depends only on the token, so compute log(N / df) once per
  # title token instead of once per (paragraph, token) pair. The guard avoids
  # log(0) when no paragraph survived filtering (nothing is scored then).
  num_paragraphs = float(len(ref_paragraph_info))
  log_idf = {}
  if num_paragraphs:
    log_idf = {
        token: math.log(num_paragraphs / max(doc_counts[token], 1))
        for token in title_tokens
    }

  for info in ref_paragraph_info:
    score = 0.
    for token in title_tokens:
      score += info["counts"][token] * log_idf[token]
    info["score"] = score

  ref_paragraph_info.sort(key=lambda el: el["score"], reverse=True)
  return [info["content"] for info in ref_paragraph_info]
def produce_examples(shard_ids, wikis_dir, refs_dir, urls_dir, vocab_path,
                     out_filepaths):
  """Produce examples from shard_ids to out_filepaths.

  For every Wikipedia article in the given input shards this:
    * joins the article with the content of its references,
    * skips it when no references, or fewer than _MIN_REFS, were found,
    * encodes its sections as targets and skips it when the encoded lead
      section is shorter than _MIN_LEADSECTION_TOKENS,
    * ranks reference paragraphs by tf-idf against the article title,
    * encodes title + EOT + ranked paragraphs as inputs,
    * writes {"inputs", "targets", "section_boundaries"} examples to
      out_filepaths through generator_utils.generate_files.

  Once all shards are processed, summary statistics are written as JSON to
  `stats.<first shard id>.json` in the directory of the first output file.

  Args:
    shard_ids: list of input shard ids to process.
    wikis_dir: directory with Wikipedia article content shards.
    refs_dir: directory with reference content shards.
    urls_dir: directory with the per-shard wiki -> reference URL JSON files.
    vocab_path: path of the SubwordTextEncoder vocabulary file.
    out_filepaths: output file paths for the encoded examples.
  """
  tf.logging.info("Processing %d input shards into %d output files.",
                  len(shard_ids), len(out_filepaths))

  vocab = text_encoder.SubwordTextEncoder(vocab_path)
  eot_ids = vocab.encode(EOT)

  def example_generator():
    """Generate Example dicts."""
    stats = dict(total_original_wikis=0, total_original_refs=0,
                 total_found_refs=0, ref_lengths=[], wiki_original_refs=[],
                 wiki_found_refs=[], wikis_skipped_no_refs=0,
                 wikis_skipped_short_lead=0, num_wikis_written=0)
    ref_files_by_shard = _references_files_by_shard(refs_dir)
    for shard_id in shard_ids:
      tf.logging.info("Processing shard %d", shard_id)
      wiki_urls = _wiki_urls_for_shard(shard_id, urls_dir)
      tf.logging.info("Loaded wiki URLs for shard")
      refs_content = _references_content(ref_files_by_shard[shard_id])
      tf.logging.info("Loaded reference content for shard")
      for i, wiki in enumerate(_wiki_articles(shard_id, wikis_dir)):
        if not i % 1000:
          tf.logging.info("Processing wiki index %d for shard %d", i, shard_id)
        stats["total_original_wikis"] += 1

        # Get reference content
        wiki_ref_content = []
        ref_urls = wiki_urls[wiki.url]["refs"]
        stats["total_original_refs"] += len(ref_urls)
        stats_wiki_original_refs = len(ref_urls)
        stats_wiki_found_refs = 0
        for ref_url in ref_urls:
          ref_content = refs_content.get(ref_url)
          if not ref_content:
            continue
          stats["total_found_refs"] += 1
          stats["ref_lengths"].append(len(ref_content))
          stats_wiki_found_refs += 1
          wiki_ref_content.append(ref_content)

        stats["wiki_original_refs"].append(stats_wiki_original_refs)
        stats["wiki_found_refs"].append(stats_wiki_found_refs)
        if not wiki_ref_content or len(wiki_ref_content) < _MIN_REFS:
          # No/few refs were found
          stats["wikis_skipped_no_refs"] += 1
          continue

        # Construct targets from article sections first. The lead-section
        # check needs only the targets, so dropped articles skip the costly
        # tf-idf ranking and the encoding of up to ~1M input tokens below.
        targets, section_boundaries = _encode_wiki_sections(
            wiki.sections, vocab)

        # Skip if lead section is too short
        if (not section_boundaries or
            section_boundaries[0] < _MIN_LEADSECTION_TOKENS):
          stats["wikis_skipped_short_lead"] += 1
          continue

        # Rank reference paragraphs with TFIDF
        wiki_title = _normalize_text(wiki.title)
        ranked_paragraphs = rank_reference_paragraphs(wiki_title,
                                                      wiki_ref_content)

        # Construct inputs from Wiki title and references
        inputs = []
        inputs.extend(vocab.encode(wiki_title))
        inputs.extend(eot_ids)
        for paragraph in ranked_paragraphs:
          # Length cap, checked before each paragraph: the paragraph that
          # crosses 1e6 tokens is still included.
          if len(inputs) >= 1e6:
            break
          paragraph += " "
          inputs.extend(vocab.encode(paragraph))

        inputs.append(text_encoder.EOS_ID)
        targets.append(text_encoder.EOS_ID)

        stats["num_wikis_written"] += 1
        yield {
            "inputs": inputs,
            "targets": targets,
            "section_boundaries": section_boundaries,
        }

    tf.logging.info("Total: %d, Skipped: %d",
                    stats["num_wikis_written"],
                    stats["total_original_wikis"] - stats["num_wikis_written"])
    tf.logging.info("Total refs: %d, Skipped refs: %d",
                    stats["total_found_refs"],
                    stats["total_original_refs"] - stats["total_found_refs"])
    stats_fname = os.path.join(os.path.split(out_filepaths[0])[0],
                               "stats.%d.json" % shard_ids[0])
    with tf.gfile.Open(stats_fname, "w") as f:
      f.write(json.dumps(stats))

  generator_utils.generate_files(example_generator(), out_filepaths)
def _format_title(title):
return " == %s == " % title
def _encode_wiki_sections(sections, vocab):
  """Encodes sections with vocab. Returns ids and section boundaries.

  Every section's normalized text is encoded. Each section after the first
  is also prefixed with its formatted, normalized title; the first section's
  title (the article title) is left out.

  Returns:
    (ids, section_boundaries): the concatenated token ids and, per section,
    the value of len(ids) right after that section was appended.
  """
  ids = []
  boundaries = []
  for index, section in enumerate(sections):
    is_lead = index == 0
    if not is_lead:
      ids.extend(vocab.encode(_format_title(_normalize_text(section.title))))
    ids.extend(vocab.encode(_normalize_text(section.text)))
    boundaries.append(len(ids))
  return ids, boundaries
def _process_folders(tmp_dir):
  """Returns all paths in tmp_dir whose names start with PROCESS_FOLDER_PREFIX."""
  pattern = os.path.join(tmp_dir, PROCESS_FOLDER_PREFIX) + "*"
  return tf.gfile.Glob(pattern)
def extract_references_from_wets(wet_files, metadata_dir, out_dir,
                                 tmp_dir=None):
  """Extract references from WET files into sharded output files.

  For each WET file, the JSON metadata file at
  `<metadata_dir>/<wet basename>` + cc_utils.METADTA_SUFFIX maps record URLs
  to the reference shard ids that need them. Every WET record whose URL is
  in that mapping is serialized as a tf.train.Example (url, content) and
  written to each listed shard.

  Args:
    wet_files: WET file paths. Entries starting with "http" are downloaded
      (into tmp_dir); others are read from the local path.
    metadata_dir: directory holding the per-WET metadata JSON files.
    out_dir: output directory for the reference shard TFRecord files.
    tmp_dir: download directory for remote WET files. Defaults to
      tempfile.gettempdir().
  """
  # Setup output files
  shard_files = make_ref_shard_files(out_dir)
  try:
    num_refs = 0
    for i, wet_file in enumerate(wet_files):
      num_refs_in_wet = 0
      tf.logging.info("Processing file %d", i)

      # Read metadata file
      metadata_fname = os.path.join(
          metadata_dir, os.path.basename(wet_file)) + cc_utils.METADTA_SUFFIX
      with tf.gfile.Open(cc_utils.readahead(metadata_fname)) as f:
        wet_metadata = json.loads(f.read())

      if not wet_metadata:
        # No references in this WET file
        continue

      if wet_file.startswith("http"):
        # download
        if not tmp_dir:
          tmp_dir = tempfile.gettempdir()
        record_gen = cc_utils.wet_records_from_url(wet_file, tmp_dir)
      else:
        # local
        record_gen = cc_utils.wet_records_from_file_obj(
            cc_utils.gzip_memfile(wet_file), take_ownership=True)

      for wet_record in record_gen:
        shard_ids = wet_metadata.get(wet_record.url)
        if not shard_ids:
          # URL not in dataset
          continue

        # Serialize and write out
        ex = _make_example_from_record(wet_record)
        ex_str = ex.SerializeToString()
        for shard_id in shard_ids:
          shard_files[shard_id].write(ex_str)
        num_refs += 1
        num_refs_in_wet += 1

      tf.logging.info("Wrote out %d references for this WET", num_refs_in_wet)

    tf.logging.info("Wrote out %d references total", num_refs)
  finally:
    # Close every shard writer even when processing fails part-way, so the
    # handles are released and the already-written data gets flushed.
    for shard_file in shard_files:
      shard_file.close()
================================================
FILE: tensor2tensor/data_generators/wikitext103.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data generators for wikitext-103.
Wikitext-103: Long term dependency language modeling dataset
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import zipfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
def _build_vocab(filename, vocab_dir, vocab_name):
  """Reads a file to build a vocabulary.

  If `vocab_dir/vocab_name` already exists, it is loaded as-is. Otherwise
  the whitespace-separated words of `filename` are counted and ordered by
  descending frequency, with ties broken alphabetically. The result is
  stored at that path and used to build the encoder.

  Args:
    filename: file to read list of words from.
    vocab_dir: directory where to save the vocabulary.
    vocab_name: vocab file name.

  Returns:
    text encoder.
  """
  vocab_path = os.path.join(vocab_dir, vocab_name)
  if not tf.gfile.Exists(vocab_path):
    # Count line by line instead of reading and splitting the whole file at
    # once. Line breaks are whitespace, so the counts are identical, but peak
    # memory no longer scales with the corpus size.
    counter = collections.Counter()
    with tf.gfile.GFile(filename, "r") as f:
      for line in f:
        counter.update(line.split())
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    # Unlike `words, _ = zip(*count_pairs)`, this also handles an empty file.
    words = tuple(word for word, _ in count_pairs)
    encoder = text_encoder.TokenTextEncoder(None, vocab_list=words)
    encoder.store_to_file(vocab_path)
  else:
    encoder = text_encoder.TokenTextEncoder(vocab_path)
  return encoder
def _maybe_download_corpus(tmp_dir, vocab_type):
  """Download and unpack the corpus.

  The archive is extracted only when its target directory does not exist
  yet. This keeps repeated calls (generate_samples calls this once per
  dataset split) from re-extracting the corpus each time.

  Args:
    tmp_dir: directory containing dataset.
    vocab_type: which vocabulary are we using. CHARACTER selects the raw
      release (wikitext-103-raw); anything else selects wikitext-103.

  Returns:
    A (train_file, valid_file, test_file) tuple of paths.

  Raises:
    AssertionError: if any of the three split files is not found.
  """
  if vocab_type == text_problems.VocabType.CHARACTER:
    dataset_url = ("https://s3.amazonaws.com/research.metamind.io/wikitext"
                   "/wikitext-103-raw-v1.zip")
    dir_name = "wikitext-103-raw"
  else:
    dataset_url = ("https://s3.amazonaws.com/research.metamind.io/wikitext"
                   "/wikitext-103-v1.zip")
    dir_name = "wikitext-103"

  fname = os.path.basename(dataset_url)
  compressed_filepath = generator_utils.maybe_download(tmp_dir, fname,
                                                       dataset_url)
  unpacked_dir = os.path.join(tmp_dir, dir_name)
  if not tf.gfile.Exists(unpacked_dir):
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(compressed_filepath, "r") as zip_ref:
      zip_ref.extractall(tmp_dir)

  train_file, valid_file, test_file = None, None, None
  for f in tf.gfile.Glob(os.path.join(unpacked_dir, "*")):
    fname = os.path.basename(f)
    if "train" in fname:
      train_file = f
    elif "valid" in fname:
      valid_file = f
    elif "test" in fname:
      test_file = f

  assert train_file, "Training file not found"
  assert valid_file, "Validation file not found"
  assert test_file, "Testing file not found"

  return train_file, valid_file, test_file
@registry.register_problem
class LanguagemodelWikitext103(text_problems.Text2SelfProblem):
  """Wikitext103 dataset token-level."""

  @property
  def dataset_splits(self):
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 10,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 1,
    }, {
        "split": problem.DatasetSplit.TEST,
        "shards": 1,
    }]

  @property
  def is_generate_per_split(self):
    return True

  @property
  def vocab_type(self):
    return text_problems.VocabType.TOKEN

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Returns a generator of {"targets": line} samples for one split.

    Each non-empty line of the split's file is yielded with its whitespace
    collapsed to single spaces. For the TRAIN split with a TOKEN vocabulary,
    the word vocabulary is also built from the training file (via
    _build_vocab) before the generator is returned.

    Raises:
      ValueError: if dataset_split is not TRAIN, EVAL or TEST.
    """
    train_file, valid_file, test_file = _maybe_download_corpus(
        tmp_dir, self.vocab_type)

    if dataset_split == problem.DatasetSplit.TRAIN:
      filepath = train_file
      if self.vocab_type == text_problems.VocabType.TOKEN:
        _build_vocab(train_file, data_dir, self.vocab_filename)
    elif dataset_split == problem.DatasetSplit.EVAL:
      filepath = valid_file
    elif dataset_split == problem.DatasetSplit.TEST:
      filepath = test_file
    else:
      # Fail fast here instead of hitting an unbound `filepath` later, when
      # the returned generator is first iterated.
      raise ValueError("Unknown dataset split: %r" % (dataset_split,))

    def _generate_samples():
      with tf.gfile.GFile(filepath, "r") as f:
        for line in f:
          line = " ".join(line.strip().split())
          if line:
            yield {"targets": line}

    return _generate_samples()
@registry.register_problem
class LanguagemodelWikitext103Characters(LanguagemodelWikitext103):
  """Wikitext-103, character-level.

  The CHARACTER vocab type makes _maybe_download_corpus pick the raw
  release. It also means generate_samples builds no word vocabulary.
  """

  @property
  def vocab_type(self):
    """Character-level vocabulary."""
    return text_problems.VocabType.CHARACTER
@registry.register_problem
class LanguagemodelWikitext103L4k(LanguagemodelWikitext103):
  """Wikitext-103, token-level, with examples up to 4,096 tokens long."""

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Packs consecutive lines into examples shorter than sequence_length.

    Lines are appended to the current example while the total token count
    stays below `sequence_length`. Otherwise the current example is emitted
    and a new one starts with that line. A single line that already reaches
    `sequence_length` becomes an example of its own. The final partial
    example is emitted too, and empty examples never are.
    """
    samples_by_line = super(LanguagemodelWikitext103L4k,
                            self).generate_samples(data_dir, tmp_dir,
                                                   dataset_split)

    def _generate_samples():
      tokens = []
      for sample in samples_by_line:
        sample_tokens = sample["targets"].split()
        if len(tokens) + len(sample_tokens) < self.sequence_length:
          tokens.extend(sample_tokens)
        else:
          # `tokens` is empty only when the very first line alone reaches
          # sequence_length; do not emit an empty example for it.
          if tokens:
            yield {"targets": " ".join(tokens)}
          tokens = sample_tokens
      # Flush the last accumulated chunk, which the loop never emits.
      if tokens:
        yield {"targets": " ".join(tokens)}

    return _generate_samples()

  def max_length(self, model_hparams):
    return model_hparams.split_to_length or self.sequence_length

  @property
  def sequence_length(self):
    """Length of each example (in tokens)."""
    return 4096
@registry.register_problem
class LanguagemodelWikitext103L16k(LanguagemodelWikitext103L4k):
  """Wikitext-103, token-level, with examples up to 16,384 tokens long.

  Same packing as LanguagemodelWikitext103L4k, only with a longer length.
  """

  @property
  def sequence_length(self):
    """Length of each example (in tokens): 2**14 == 16384."""
    return 2**14
================================================
FILE: tensor2tensor/data_generators/wnli.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data generators for the Winograd NLI dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import zipfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
# Module-level alias of text_encoder's end-of-sequence token. The classes
# below do not use it; presumably it is kept for importers of this module
# (verify before removing).
EOS = text_encoder.EOS
@registry.register_problem
class WinogradNLI(text_problems.TextConcat2ClassProblem):
  """Winograd NLI classification problems."""

  # Link to data from GLUE: https://gluebenchmark.com/tasks
  _WNLI_URL = ("https://firebasestorage.googleapis.com/v0/b/"
               "mtl-sentence-representations.appspot.com/o/"
               "data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-"
               "4bd7-99a5-5e00222e0faf")

  @property
  def is_generate_per_split(self):
    return True

  @property
  def dataset_splits(self):
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 1,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 1,
    }]

  @property
  def approx_vocab_size(self):
    return 2**13  # 8k vocab suffices for this small dataset.

  @property
  def vocab_filename(self):
    return "vocab.wnli.%d" % self.approx_vocab_size

  @property
  def num_classes(self):
    return 2

  def class_labels(self, data_dir):
    del data_dir
    # Note this binary classification is different from usual MNLI.
    return ["not_entailment", "entailment"]

  def _maybe_download_corpora(self, tmp_dir):
    """Downloads and extracts WNLI.zip unless tmp_dir/WNLI already exists.

    Returns:
      Path of the extracted WNLI directory.
    """
    wnli_filename = "WNLI.zip"
    wnli_finalpath = os.path.join(tmp_dir, "WNLI")
    if not tf.gfile.Exists(wnli_finalpath):
      zip_filepath = generator_utils.maybe_download(
          tmp_dir, wnli_filename, self._WNLI_URL)
      # The context manager closes the archive even if extraction fails.
      with zipfile.ZipFile(zip_filepath, "r") as zip_ref:
        zip_ref.extractall(tmp_dir)
    return wnli_finalpath

  def example_generator(self, filename):
    """Yields {"inputs": [sentence1, sentence2], "label": int} per TSV row.

    The first line is a header and is skipped. Every other line must hold
    exactly four tab-separated fields: an ignored first field, the two
    sentences, and the integer label.
    """
    # Open the file in a context manager so it is closed once the generator
    # is exhausted or closed, rather than whenever it is garbage collected.
    with tf.gfile.Open(filename, "rb") as tsv_file:
      for idx, line in enumerate(tsv_file):
        if idx == 0: continue  # skip header
        line = text_encoder.to_unicode_utf8(line.strip())
        _, s1, s2, l = line.split("\t")
        inputs = [s1, s2]
        yield {
            "inputs": inputs,
            "label": int(l)
        }

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    wnli_dir = self._maybe_download_corpora(tmp_dir)
    if dataset_split == problem.DatasetSplit.TRAIN:
      filesplit = "train.tsv"
    else:
      filesplit = "dev.tsv"

    filename = os.path.join(wnli_dir, filesplit)
    for example in self.example_generator(filename):
      yield example
@registry.register_problem
class WinogradNLICharacters(WinogradNLI):
  """Winograd NLI classification problems, character level."""

  @property
  def vocab_type(self):
    """Character-level vocabulary instead of the parent's default."""
    return text_problems.VocabType.CHARACTER

  def global_task_id(self):
    """Returns problem.TaskID.EN_NLI."""
    return problem.TaskID.EN_NLI
================================================
FILE: tensor2tensor/data_generators/wsj_parsing.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data generators for parsing data-sets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
# Path prefix of the parsing tree files. In this module it is referenced only
# by the commented-out implementation inside parsing_token_generator.
flags.DEFINE_string("parsing_path", "", "Path to parsing files in tmp_dir.")
FLAGS = flags.FLAGS
@registry.register_problem
class WsjParsing(text_problems.Text2textTmpdir):
  """Generate vocabulary and training data for parsing.

  Inputs are the words of a WSJ tree and targets are its linearized
  bracketing, as produced by words_and_tags_from_wsj_tree.
  """

  # These files are used for vocab generation
  TRAIN_FILES = ("wsj.train.text.txt", "wsj.train.tags.txt")
  # These files are used for generating encoded samples
  TRAIN_FILES_TREE = "wsjTrain.trees"
  EVAL_FILES_TREE = "wsjEval.trees"

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Yields {"inputs": words, "targets": tags} per line of the tree file.

    TRAIN reads TRAIN_FILES_TREE; every other split reads EVAL_FILES_TREE.
    """
    del data_dir
    if dataset_split == problem.DatasetSplit.TRAIN:
      tree_name = self.TRAIN_FILES_TREE
    else:
      tree_name = self.EVAL_FILES_TREE
    with tf.gfile.GFile(os.path.join(tmp_dir, tree_name), mode="r") as trees:
      for tree_line in trees:
        words, tags = words_and_tags_from_wsj_tree(tree_line)
        yield {"inputs": words, "targets": tags}

  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
    """Encodes generate_samples output with the problem's vocabulary."""
    samples = self.generate_samples(data_dir, tmp_dir, dataset_split)
    vocab = self.get_or_create_vocab(data_dir, tmp_dir)
    return text_problems.text2text_generate_encoded(
        samples, vocab, has_inputs=self.has_inputs)

  def generate_text_for_vocab(self, data_dir, tmp_dir):
    """Yields alternating input/target lines from TRAIN_FILES.

    Stops after `max_samples_for_vocab` pairs when that value is truthy.
    """
    inputs_file, targets_file = (
        os.path.join(tmp_dir, name) for name in self.TRAIN_FILES)
    pairs = text_problems.text2text_txt_iterator(inputs_file, targets_file)
    for num_seen, sample in enumerate(pairs, 1):
      yield sample["inputs"]
      yield sample["targets"]
      if (self.max_samples_for_vocab and
          num_seen >= self.max_samples_for_vocab):
        break

  @property
  def max_samples_for_vocab(self):
    """Number of sample pairs used for vocab generation."""
    return 1000
def words_and_tags_from_wsj_tree(tree_string):
  """Generates linearized trees and tokens from the wsj tree format.

  It uses the linearized algorithm described in https://arxiv.org/abs/1412.7449.
  Every opening "(X" emits tag X. A leaf "word)" drops its POS tag from the
  open stack, and each additional ")" emits "/X" for the constituent it
  closes. The outermost (TOP) open and close tags are removed.

  Args:
    tree_string: tree in wsj format

  Returns:
    tuple: (words, linearized tree), each a space-joined string.
  """
  open_labels = []
  linearized = []
  leaves = []
  for token in tree_string.strip().split():
    if token[0] == "(":
      label = token[1:]
      linearized.append(label)
      open_labels.append(label)
      continue
    assert token[-1] == ")"
    open_labels.pop()  # Discard the POS tag of this leaf.
    while token[-2] == ")":
      linearized.append("/" + open_labels.pop())
      token = token[:-1]
    leaves.append(token[:-1])
  # Strip the "TOP" open and close tags.
  return " ".join(leaves), " ".join(linearized[1:-1])
def token_generator(tree_path, source_token_vocab, target_token_vocab,
                    eos=None):
  """Generator for parsing as a sequence-to-sequence task that uses tokens.

  Each line of the file at tree_path holds one WSJ tree. The line is split
  into its words (source) and linearized tags (target) with
  words_and_tags_from_wsj_tree, and each side is encoded with its
  vocabulary.

  Args:
    tree_path: path to the file with WSJ format trees, one per line.
    source_token_vocab: GenericVocabulary object for source vocabulary.
    target_token_vocab: GenericVocabulary object for target vocabulary.
    eos: integer to append at the end of each sequence (default: None).

  Yields:
    A dictionary {"inputs": source_ids, "targets": target_ids} of integer
    lists; when eos is given, it is appended to both.
  """
  suffix = [eos] if eos is not None else []
  with tf.gfile.GFile(tree_path, mode="r") as tree_file:
    while True:
      tree_line = tree_file.readline()
      if not tree_line:
        break
      words, tags = words_and_tags_from_wsj_tree(tree_line)
      yield {
          "inputs": source_token_vocab.encode(words.strip()) + suffix,
          "targets": target_token_vocab.encode(tags.strip()) + suffix,
      }
def parsing_token_generator(data_dir, tmp_dir, train, source_vocab_size,
                            target_vocab_size):
  """Generator for parsing as a sequence-to-sequence task that uses tokens.

  This generator assumes the files parsing_{train,dev}.trees, which contain
  trees in WSJ format. It is not implemented yet.

  Args:
    data_dir: path to the data directory.
    tmp_dir: path to temporary storage directory.
    train: whether we're training or not.
    source_vocab_size: source vocab size.
    target_vocab_size: target vocab size.

  Returns:
    A generator to a dictionary of inputs and outputs (once implemented).

  Raises:
    AssertionError: always, because vocabulary generation is not
      implemented.
  """
  # TODO(lukaszkaiser): Correct these calls to generate vocabularies. No data
  # sources are being passed.
  del (data_dir, tmp_dir, train, source_vocab_size, target_vocab_size)
  # Raise explicitly rather than `assert False`: under `python -O` assert
  # statements are stripped and this function would silently return None.
  # The exception type is unchanged for existing callers.
  raise AssertionError("Vocabulary generation not implemented")
  # source_symbolizer_vocab = generator_utils.get_or_generate_vocab(
  #     data_dir, tmp_dir, "wsj_source.vocab.%d" % source_vocab_size,
  #     source_vocab_size)
  # target_symbolizer_vocab = generator_utils.get_or_generate_vocab(
  #     data_dir, tmp_dir, "wsj_target.vocab.%d" % target_vocab_size,
  #     target_vocab_size)
  # filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
  # tree_filepath = os.path.join(tmp_dir, filename)
  # return token_generator(tree_filepath, source_symbolizer_vocab,
  #                        target_symbolizer_vocab, 1)
================================================
FILE: tensor2tensor/data_generators/yelp_full.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Yelp dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
@registry.register_problem
class SentimentYelpFull(text_problems.Text2ClassProblem):
  """Yelp dataset."""

  URL = "https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz"

  @property
  def is_generate_per_split(self):
    return True

  @property
  def dataset_splits(self):
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 10,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 1,
    }]

  @property
  def approx_vocab_size(self):
    return 2**13  # 8k vocab suffices for this small dataset.

  @property
  def num_classes(self):
    return 5

  def class_labels(self, data_dir):
    del data_dir
    return ["1", "2", "3", "4", "5"]

  def doc_generator(self, yelp_dir, dataset, include_label=False):
    """Yields review texts (optionally with raw labels) from `<dataset>.csv`.

    Rows have the form `"label","review text"`. They are parsed with the csv
    module, which handles quoting correctly (a doubled `""` inside a review
    becomes `"`). The file is streamed rather than loaded whole into memory.

    Args:
      yelp_dir: directory holding the extracted train.csv / test.csv.
      dataset: "train" or "test".
      include_label: if True, yield (doc, label) pairs, where label is the
        raw label string from the file ("1".."5"); otherwise yield docs.
    """
    import csv  # pylint: disable=g-import-not-at-top
    file_path = os.path.join(yelp_dir, dataset + ".csv")
    with tf.gfile.Open(file_path) as yelp_f:
      for row in csv.reader(yelp_f):
        if len(row) < 2:
          # Skip blank or malformed rows.
          continue
        label = row[0]
        doc = row[1].strip()
        if include_label:
          yield doc, label
        else:
          yield doc

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Generate examples."""
    # Download and extract
    compressed_filename = os.path.basename(self.URL)
    download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
                                                   self.URL)
    yelp_dir = os.path.join(tmp_dir, "yelp_review_full_csv")
    if not tf.gfile.Exists(yelp_dir):
      with tarfile.open(download_path, "r:gz") as tar:
        tar.extractall(tmp_dir)

    # Generate examples
    train = dataset_split == problem.DatasetSplit.TRAIN
    dataset = "train" if train else "test"
    for doc, label in self.doc_generator(yelp_dir, dataset, include_label=True):
      yield {
          "inputs": doc,
          # CSV labels are 1-based ("1".."5"). Class ids must lie in
          # [0, num_classes) and index class_labels, so shift down by one.
          "label": int(label) - 1,
      }
@registry.register_problem
class SentimentYelpFullCharacters(SentimentYelpFull):
  """Yelp dataset, character level."""

  @property
  def vocab_type(self):
    """Character-level vocabulary instead of the parent's default."""
    return text_problems.VocabType.CHARACTER

  def global_task_id(self):
    """Returns problem.TaskID.EN_CHR_SENT."""
    return problem.TaskID.EN_CHR_SENT
================================================
FILE: tensor2tensor/data_generators/yelp_polarity.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Yelp dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
@registry.register_problem
class SentimentYelpPolarity(text_problems.Text2ClassProblem):
  """Yelp dataset."""

  URL = "https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz"

  @property
  def is_generate_per_split(self):
    return True

  @property
  def dataset_splits(self):
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 10,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 1,
    }]

  @property
  def approx_vocab_size(self):
    return 2**13  # 8k vocab suffices for this small dataset.

  @property
  def num_classes(self):
    return 2

  def class_labels(self, data_dir):
    del data_dir
    return ["1", "2"]

  def doc_generator(self, yelp_dir, dataset, include_label=False):
    """Yields review texts (optionally with raw labels) from `<dataset>.csv`.

    Rows have the form `"label","review text"`. They are parsed with the csv
    module, which handles quoting correctly (a doubled `""` inside a review
    becomes `"`). The file is streamed rather than loaded whole into memory.

    Args:
      yelp_dir: directory holding the extracted train.csv / test.csv.
      dataset: "train" or "test".
      include_label: if True, yield (doc, label) pairs, where label is the
        raw label string from the file ("1" or "2"); otherwise yield docs.
    """
    import csv  # pylint: disable=g-import-not-at-top
    file_path = os.path.join(yelp_dir, dataset + ".csv")
    with tf.gfile.Open(file_path) as yelp_f:
      for row in csv.reader(yelp_f):
        if len(row) < 2:
          # Skip blank or malformed rows.
          continue
        label = row[0]
        doc = row[1].strip()
        if include_label:
          yield doc, label
        else:
          yield doc

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Generate examples."""
    # Download and extract
    compressed_filename = os.path.basename(self.URL)
    download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
                                                   self.URL)
    yelp_dir = os.path.join(tmp_dir, "yelp_review_polarity_csv")
    if not tf.gfile.Exists(yelp_dir):
      with tarfile.open(download_path, "r:gz") as tar:
        tar.extractall(tmp_dir)

    # Generate examples
    train = dataset_split == problem.DatasetSplit.TRAIN
    dataset = "train" if train else "test"
    for doc, label in self.doc_generator(yelp_dir, dataset, include_label=True):
      yield {
          "inputs": doc,
          # CSV labels are 1-based ("1", "2"). Class ids must lie in
          # [0, num_classes) and index class_labels, so shift down by one.
          "label": int(label) - 1,
      }
@registry.register_problem
class SentimentYelpPolarityCharacters(SentimentYelpPolarity):
  """Character-level variant of `SentimentYelpPolarity`."""

  def global_task_id(self):
    # Reuses the English character-level sentiment task id.
    return problem.TaskID.EN_CHR_SENT

  @property
  def vocab_type(self):
    # Tokenize the review text character by character.
    return text_problems.VocabType.CHARACTER
================================================
FILE: tensor2tensor/envs/__init__.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Environments defined in T2T. Imports here force registration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.envs import gym_env_problem
from tensor2tensor.envs import tic_tac_toe_env
from tensor2tensor.envs import tic_tac_toe_env_problem
================================================
FILE: tensor2tensor/envs/env_problem.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for envs that store their history.
EnvProblem subclasses Problem and also implements the Gym interface (step,
reset, render, close, seed)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
from gym.core import Env
import numpy as np
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.envs import gym_spaces_utils
from tensor2tensor.envs import trajectory
from tensor2tensor.layers import modalities
from tensor2tensor.utils import contrib
import tensorflow.compat.v1 as tf
# Names for data fields in stored tf.Examples. These are the keys of the dicts
# yielded by `EnvProblem._generate_time_steps` and the features declared in
# `EnvProblem.example_reading_spec`.
TIMESTEP_FIELD = "timestep"
ACTION_FIELD = "action"
RAW_REWARD_FIELD = "raw_reward"
PROCESSED_REWARD_FIELD = "reward"
DONE_FIELD = "done"
OBSERVATION_FIELD = "observation"
class EnvProblem(Env, problem.Problem):
  """Base class of an env which generates data like a problem class.

  EnvProblem is both a gym Env and a Problem, since it subclasses both.

  Conceptually it contains `batch_size` environments on which step (and reset)
  are called. The data that is generated by the repeated application of step and
  reset is stored within this class and is persisted on disk when we call
  `generate_data` on it.

  Subclasses *should* override the following functions:
  - initialize_environments
  - observation_space
  - action_space
  - reward_range
  - _reset
  - _step
  - _render

  In addition, they should override the following functions, which are used in
  the `hparams` function to return modalities and vocab_sizes.
  - input_modality
  - input_vocab_size
  - target_modality
  - target_vocab_size
  - action_modality
  - reward_modality

  NON NATIVELY BATCHED ENVS:

  The implementation for cases where the env is not batched by default is
  `gym_env_problem.GymEnvProblem`.

  NATIVELY BATCHED ENVS:

  If however, our env is a neural network, which can be batched by default, we
  should:

  # 1 - Give it a gym style interface, by overriding observation_space and
  action_space.

  # 2 - Override `_reset` and `_step` to do the reset and step in a natively
  batched manner.

  # 3 - More generally any function that iterates over the self._env list will
  need to be overridden, ex: `_verify_same_spaces` and `initialize_environments`

  KNOWN LIMITATIONS:

  - observation_space and action_space should be subclasses of gym.spaces
  - not all subclasses of gym.spaces are supported
  """

  def __init__(self,
               batch_size=None,
               discrete_rewards=True,
               parallelism=1,
               **env_kwargs):
    """Initializes this class by creating the envs and managing trajectories.

    Args:
      batch_size: (int or None) How many envs to make in the non natively
        batched mode. If None, `initialize` must be called later.
      discrete_rewards: (bool) whether to round the rewards to the nearest
        integer.
      parallelism: (int) If this is greater than one then we run the envs in
        parallel using multi-threading.
      **env_kwargs: (dict) Additional kwargs to pass to the environments.
    """

    # Call the super's ctor.
    problem.Problem.__init__(self, was_reversed=False, was_copy=False)

    # An env generates data when it is given actions by an agent which is either
    # a policy or a human -- this is supposed to be the `id` of the agent.
    #
    # In practice, this is used only to store (and possibly retrieve) history
    # to an appropriate directory.
    self._agent_id = "default"

    # If set, we discretize the rewards and treat them as integers.
    self._discrete_rewards = discrete_rewards

    # A data structure to hold the `batch_size` currently active trajectories
    # and also the ones that are completed, i.e. done.
    self._trajectories = None

    self._batch_size = None

    self._parallelism = None
    # The parallelism is passed in via env_kwargs because it will be used by
    # `GymEnvProblem` to parallelize env actions across a batch.
    env_kwargs["parallelism"] = parallelism

    if batch_size is not None:
      self.initialize(batch_size=batch_size, **env_kwargs)

  @property
  def batch_size(self):
    # TODO(afrozm): I've added this here since it is being used in a lot of
    # places in ppo_learner.py -- re-evaluate if needed.
    return self._batch_size

  @property
  def trajectories(self):
    return self._trajectories

  @trajectories.setter
  def trajectories(self, trajectories_):
    # Replacing the trajectories is only allowed with the same batch size.
    assert self.trajectories.batch_size == trajectories_.batch_size
    self._trajectories = trajectories_

  def initialize(self, batch_size=1, **kwargs):
    """Creates the envs and an empty `BatchTrajectory` of `batch_size`."""
    self.initialize_environments(batch_size=batch_size, **kwargs)

    self._batch_size = batch_size

    # This data structure stores the history of each env.
    #
    # NOTE: Even if the env is a NN and can step in all batches concurrently, it
    # is still valuable to store the trajectories separately.
    self._trajectories = trajectory.BatchTrajectory(batch_size=batch_size)

    # Assert that *all* the above are now set, we should do this since
    # subclasses can override `initialize_environments`.
    self.assert_common_preconditions()
    assert self.observation_space is not None
    assert self.action_space is not None
    assert self.reward_range is not None

  def initialize_environments(self, batch_size=1, parallelism=1, **kwargs):
    """Initializes the environments.

    Args:
      batch_size: (int) Number of envs to initialize.
      parallelism: (int) If this is greater than one then we allow the
        implementation to use multi-threading to step the envs.
      **kwargs: (dict) Any additional args needed to initialize the envs.
    """
    raise NotImplementedError

  def assert_common_preconditions(self):
    # Hook for subclasses; the base implementation checks nothing.
    pass

  @property
  def observation_space(self):
    raise NotImplementedError

  @property
  def observation_spec(self):
    """The spec for reading an observation stored in a tf.Example."""
    return gym_spaces_utils.gym_space_spec(self.observation_space)

  def process_observations(self, observations):
    """Processes observations prior to saving in the trajectories.

    Args:
      observations: (np.ndarray) observations to be processed.

    Returns:
      processed observation
    """
    return observations

  @property
  def action_space(self):
    raise NotImplementedError

  @property
  def action_spec(self):
    """The spec for reading an action stored in a tf.Example."""
    return gym_spaces_utils.gym_space_spec(self.action_space)

  @property
  def action_modality(self):
    raise NotImplementedError

  @property
  def num_actions(self):
    """Returns the number of actions in a discrete action space."""
    return gym_spaces_utils.cardinality(self.action_space)

  @property
  def reward_range(self):
    # We clip rewards to this range before processing them further, as described
    # in `process_rewards`.
    raise NotImplementedError

  @property
  def is_reward_range_finite(self):
    min_reward, max_reward = self.reward_range
    return (min_reward != -np.inf) and (max_reward != np.inf)

  @property
  def discrete_rewards(self):
    return self._discrete_rewards

  def process_rewards(self, rewards):
    """Clips the rewards, and optionally rounds them and casts to integer.

    Args:
      rewards: numpy array of raw (float) rewards.

    Returns:
      processed_rewards: the rewards clipped to `reward_range`. If
        `discrete_rewards` is set, they are also rounded and cast to np.int64.
    """

    min_reward, max_reward = self.reward_range

    # Clips at min and max reward.
    rewards = np.clip(rewards, min_reward, max_reward)

    if self._discrete_rewards:
      # Round to (nearest) int and convert to integral type.
      rewards = np.around(rewards, decimals=0).astype(np.int64)
    return rewards

  @property
  def is_processed_rewards_discrete(self):
    """Returns true if `self.process_rewards` returns discrete rewards."""

    # Subclasses can override, but it should match their self.process_rewards.

    # This check is a little hacky: it probes the dtype of a processed 0.0.
    return self.process_rewards(0.0).dtype == np.int64

  @property
  def num_rewards(self):
    """Returns the number of distinct rewards.

    Returns:
      Returns None if the reward range is infinite or the processed rewards
      aren't discrete, otherwise returns the number of distinct rewards.
    """

    # Pre-conditions: reward range is finite.
    #               : processed rewards are discrete.
    if not self.is_reward_range_finite:
      logging.warning("Infinite reward range, `num_rewards returning None`")
      return None
    if not self.is_processed_rewards_discrete:
      logging.warning(
          "Processed rewards are not discrete, `num_rewards` returning None")
      return None

    min_reward, max_reward = self.reward_range
    # Cast so a float reward range such as (-1.0, 1.0) still yields an integer,
    # because `hparams` uses this value as a vocab size.
    return int(max_reward - min_reward) + 1

  @property
  def input_modality(self):
    raise NotImplementedError

  @property
  def reward_modality(self):
    raise NotImplementedError

  @property
  def input_vocab_size(self):
    raise NotImplementedError

  @property
  def target_modality(self):
    raise NotImplementedError

  @property
  def target_vocab_size(self):
    raise NotImplementedError

  @property
  def unwrapped(self):
    return self

  def seed(self, seed=None):
    return [seed]

  def close(self):
    pass

  def _reset(self, indices):
    """Resets environments at indices shouldn't pre-process or record.

    Args:
      indices: list of indices of underlying envs to call reset on.

    Returns:
      np.ndarray of stacked observations from the reset-ed envs.
    """
    raise NotImplementedError

  def truncate(self, indices=None, num_to_keep=1):
    """Truncates trajectories at the specified indices."""

    if indices is None:
      indices = np.arange(self.batch_size)

    self.trajectories.truncate_trajectories(indices, num_to_keep=num_to_keep)

  def reset(self, indices=None):
    """Resets environments at given indices.

    Subclasses should override _reset to do the actual reset if something other
    than the default implementation is desired.

    NOTE: With `indices` as None the recorded trajectories are also erased since
          the expectation is that we want to re-use the whole env class from
          scratch.

    Args:
      indices: Indices of environments to reset (an array or any sequence). If
        None all envs are reset as well as trajectories are erased.

    Returns:
      Batch of initial observations of reset environments, or None if
      `indices` is empty.
    """

    if indices is None:
      self.trajectories.reset_batch_trajectories()
      indices = np.arange(self.batch_size)
    else:
      # Accept plain sequences (e.g. lists) too; `.size` below needs an array.
      indices = np.asarray(indices)

    # If this is empty (not None) then don't do anything, no env was done.
    if indices.size == 0:
      logging.warning(
          "`reset` called with empty indices array, this is a no-op.")
      return None

    # Pre-conditions: common_preconditions, see `assert_common_preconditions`.
    self.assert_common_preconditions()

    observations = self._reset(indices)
    processed_observations = self.process_observations(observations)

    # Record history.
    self.trajectories.reset(indices, processed_observations)

    return processed_observations

  def _render(self, indices, mode="human"):
    """Renders the environments with the given mode on the specified indices.

    Args:
      indices: array of indices.
      mode: rendering mode.

    Returns:
      a list of return values from the environments rendered.
    """
    raise NotImplementedError

  def render(self, indices=None, mode="human"):
    """Renders the environments with the given mode on the specified indices.

    Args:
      indices: array of indices, calls render on everything if indices is None.
      mode: rendering mode.

    Returns:
      a list of return values from the environments rendered.
    """
    if indices is None:
      indices = np.arange(self.batch_size)
    return self._render(indices, mode)

  def _step(self, actions):
    """Takes a step in all environments, shouldn't pre-process or record.

    Args:
      actions: (np.ndarray) with first dimension equal to the batch size.

    Returns:
      a tuple of stacked raw observations, raw rewards, dones and infos.
    """
    raise NotImplementedError

  def step(self, actions, infos=None):
    """Takes a step in all environments.

    Subclasses should override _step to do the actual step if something other
    than the default implementation is desired.

    Args:
      actions: Batch of actions.
      infos: (optional) a dictionary of keys and values, where all the values
        have the first dimension as batch_size.

    Returns:
      (preprocessed_observations, processed_rewards, dones, env_infos).
    """

    # Pre-conditions: common_preconditions, see `assert_common_preconditions`.
    #               : len(actions) == self.batch_size
    self.assert_common_preconditions()
    assert self.batch_size == len(actions)

    observations, raw_rewards, dones, env_infos = self._step(actions)

    # Process rewards.
    raw_rewards = raw_rewards.astype(np.float32)
    processed_rewards = self.process_rewards(raw_rewards)

    # Process observations.
    processed_observations = self.process_observations(observations)

    # Record history.
    self.trajectories.step(processed_observations, raw_rewards,
                           processed_rewards, dones, actions,
                           infos=infos)

    return processed_observations, processed_rewards, dones, env_infos

  def example_reading_spec(self):
    """Data fields to store on disk and their decoders."""

    # Subclasses can override and/or extend.

    processed_reward_type = tf.float32
    if self.is_processed_rewards_discrete:
      processed_reward_type = tf.int64

    data_fields = {
        TIMESTEP_FIELD: tf.FixedLenFeature((1,), tf.int64),
        RAW_REWARD_FIELD: tf.FixedLenFeature((1,), tf.float32),
        PROCESSED_REWARD_FIELD: tf.FixedLenFeature((1,), processed_reward_type),
        DONE_FIELD: tf.FixedLenFeature((1,), tf.int64),  # we wrote this as int.

        # Special treatment because we need to determine type and shape, also
        # enables classes to override.
        OBSERVATION_FIELD: self.observation_spec,
        ACTION_FIELD: self.action_spec,
    }

    data_items_to_decoders = {
        field: contrib.slim().tfexample_decoder.Tensor(field)
        for field in data_fields
    }

    return data_fields, data_items_to_decoders

  def hparams(self, defaults, model_hparams):
    # Usually when using the environment in a supervised setting, given the
    # observation we are predicting the reward.
    p = defaults

    # Have to add these the 'proper' way, otherwise __str__ doesn't show them.
    if "modality" not in p:
      p.add_hparam("modality", {})
    if "vocab_size" not in p:
      p.add_hparam("vocab_size", {})

    # TODO(afrozm): Document what all of these keys are and are supposed to do.
    p.modality.update({
        "inputs": self.input_modality,
        "targets": self.target_modality,
        "input_reward": self.reward_modality,
        "target_reward": self.reward_modality,
        "input_action": self.action_modality,
        "target_action": self.action_modality,
        "target_policy": modalities.ModalityType.IDENTITY,
        "target_value": modalities.ModalityType.IDENTITY,
    })

    p.vocab_size.update({
        "inputs": self.input_vocab_size,
        "targets": self.target_vocab_size,
        "input_reward": self.num_rewards,
        "target_reward": self.num_rewards,
        "input_action": self.num_actions,
        "target_action": self.num_actions,
        "target_policy": None,
        "target_value": None,
    })

    p.input_space_id = problem.SpaceID.GENERIC
    p.target_space_id = problem.SpaceID.GENERIC

  @property
  def agent_id(self):
    return self._agent_id

  @agent_id.setter
  def agent_id(self, agent_id):
    # Lets us call agent_id with integers that we increment.
    agent_id = str(agent_id)
    # We use `-` in self.dataset_filename, disallow it here for convenience.
    if "-" in agent_id:
      raise ValueError("agent_id shouldn't have - in it.")
    self._agent_id = agent_id

  def dataset_filename(self):
    return "{}-{}".format(self.name, self.agent_id)

  @property
  def num_shards(self):
    return {
        problem.DatasetSplit.TRAIN: 10,
        problem.DatasetSplit.EVAL: 1,
    }

  def _generate_time_steps(self, trajectory_list):
    """A generator to yield single time-steps from a list of trajectories.

    Args:
      trajectory_list: iterable of `trajectory.Trajectory` objects.

    Yields:
      dicts keyed by the *_FIELD constants with list values, as consumed by
      `generator_utils.generate_files` in `generate_data`.
    """
    # `example_reading_spec` reads the processed reward as int64 only when the
    # processed rewards are discrete, and as float32 otherwise. The written
    # value has to follow the same rule.
    processed_rewards_discrete = self.is_processed_rewards_discrete

    for single_trajectory in trajectory_list:
      assert isinstance(single_trajectory, trajectory.Trajectory)

      # Skip writing trajectories that have only a single time-step -- this
      # could just be a repeated reset.

      if single_trajectory.num_time_steps <= 1:
        continue

      for index, time_step in enumerate(single_trajectory.time_steps):

        # The first time-step doesn't have reward/processed_reward, if so, just
        # setting it to 0.0 / 0 should be OK.
        raw_reward = time_step.raw_reward
        if not raw_reward:
          raw_reward = 0.0

        processed_reward = time_step.processed_reward
        if not processed_reward:
          processed_reward = 0

        action = time_step.action
        if action is None:
          # The last time-step doesn't have action, and this action shouldn't be
          # used, gym's spaces have a `sample` function, so let's just sample an
          # action and use that.
          action = self.action_space.sample()
        action = gym_spaces_utils.gym_space_encode(self.action_space, action)

        if six.PY3:
          # py3 complains that, to_example cannot handle np.int64 !

          action_dtype = self.action_space.dtype
          if action_dtype in [np.int64, np.int32]:
            action = list(map(int, action))
          elif action_dtype in [np.float64, np.float32]:
            action = list(map(float, action))

          # Same with processed_reward. Casting to int unconditionally would
          # truncate continuous rewards and mismatch the float32 reading spec.
          if processed_rewards_discrete:
            processed_reward = int(processed_reward)
          else:
            processed_reward = float(processed_reward)

        assert time_step.observation is not None

        yield {
            TIMESTEP_FIELD: [index],
            ACTION_FIELD:
                action,
            # to_example errors on np.float32
            RAW_REWARD_FIELD: [float(raw_reward)],
            PROCESSED_REWARD_FIELD: [processed_reward],
            # to_example doesn't know bools
            DONE_FIELD: [int(time_step.done)],
            OBSERVATION_FIELD:
                gym_spaces_utils.gym_space_encode(self.observation_space,
                                                  time_step.observation),
        }

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    # List of files to generate data in.
    # NOTE: We don't want to shuffle, so we mark the files as shuffled.
    files_list = []
    for split, num_shards in self.num_shards.items():
      files_list.extend(self.data_filepaths(split, data_dir, num_shards, True))

    # At this point some trajectories haven't finished. However we still want to
    # write those down.

    # A simple way of doing this is to call `self.reset()` here, this will make
    # all the envs take one (extra) step, but would be a clean way to do it.
    #
    # self.reset()

    self.trajectories.complete_all_trajectories()

    # Write the completed data into these files

    num_completed_trajectories = self.trajectories.num_completed_trajectories
    num_shards = len(files_list)
    if num_completed_trajectories < num_shards:
      logging.warning(
          "Number of completed trajectories [%d] is less than "
          "the number of shards [%d], some shards maybe empty.",
          num_completed_trajectories, num_shards)

    for i, f in enumerate(files_list[:num_completed_trajectories]):
      # Start at index i of completed trajectories and take every `num_shards`
      # trajectory. This ensures that the data is approximately a balanced
      # partition of completed trajectories, also because of the above slicing
      # of files_list, i will be a valid index into completed_trajectories.
      trajectories_to_write = self.trajectories.completed_trajectories[
          i::num_shards]

      # Convert each trajectory from `trajectories_to_write` to a sequence of
      # time-steps and then send that generator to `generate_files`.

      # `cycle_every_n` isn't needed since file list given to it is a singleton.
      generator_utils.generate_files(
          self._generate_time_steps(trajectories_to_write), [f])

  def print_state(self):
    for t in self.trajectories.trajectories:
      print("---------")
      if not t.is_active:
        print("trajectory isn't active.")
        continue
      last_obs = t.last_time_step.observation
      print(str(last_obs))
================================================
FILE: tensor2tensor/envs/env_problem_utils.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to deal with EnvProblem."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import time
import gym
import numpy as np
from tensor2tensor.envs import gym_env_problem
from tensor2tensor.envs import rendered_env_problem
from tensor2tensor.rl import gym_utils
def done_indices(dones):
  """Returns the positions of the truthy entries of `dones`.

  Args:
    dones: 1-D array-like whose entries are tested for truthiness.

  Returns:
    1-D integer np.ndarray of the matching indices (possibly empty).
  """
  # argwhere gives one row per hit, shape (k, 1) for 1-D input. Dropping the
  # trailing axis leaves a flat vector of indices.
  hits = np.argwhere(dones)
  return np.squeeze(hits, axis=1)
def play_env_problem_randomly(env_problem, num_steps):
  """Plays the env problem by randomly sampling actions for `num_steps`.

  All envs are reset first. After every step, only the envs that reported
  done are reset again.

  Args:
    env_problem: an `EnvProblem` (or compatible object) whose environments
      have already been initialized.
    num_steps: number of batched steps to take.
  """
  # Reset all environments.
  env_problem.reset()

  # Play all environments, sampling random actions each time.
  for _ in range(num_steps):
    # Sample batch_size actions from the action space and stack them.
    actions = np.stack([
        env_problem.action_space.sample() for _ in range(env_problem.batch_size)
    ])

    # Execute actions, observations are stored in `env_problem`.
    _, _, dones, _ = env_problem.step(actions)

    # Reset only the envs that are done. `EnvProblem.reset` with an empty index
    # array is a no-op that logs a warning, so skip that case, as
    # `play_env_problem_with_policy` does.
    done_idxs = done_indices(dones)
    if done_idxs.size:
      env_problem.reset(indices=done_idxs)
def get_completed_trajectories_from_env(env,
                                        n_trajectories,
                                        raw_trajectory=False):
  """Returns at most `n_trajectories` completed trajectories from `env`.

  Args:
    env: object exposing `trajectories.completed_trajectories`.
    n_trajectories: maximum number of trajectories to return.
    raw_trajectory: if True return the trajectory objects themselves, otherwise
      their `as_numpy` representations.

  Returns:
    A list of trajectories (raw or numpy form), in completion order.
  """
  wanted = env.trajectories.completed_trajectories[:n_trajectories]
  if raw_trajectory:
    return wanted
  return [completed.as_numpy for completed in wanted]
def play_env_problem_with_policy(env,
                                 policy_fun,
                                 num_trajectories=1,
                                 max_timestep=None,
                                 reset=True,
                                 state=None,
                                 rng=None,
                                 temperature=1.0,
                                 boundary=32,
                                 len_history_for_policy=32,
                                 num_to_keep=1,
                                 abort_fn=None,
                                 raw_trajectory=False):
  """Plays the given env with the policy function to collect trajectories.

  Args:
    env: environment object, should be a subclass of env_problem.EnvProblem.
    policy_fun: callable invoked as
      `policy_fun(observations, lengths, state=state, rng=rng)` with padded
      observations of shape (B, RT) + OBS and lengths of shape (B,). It returns
      `(log_probs, value_preds, state, rng)`, where log_probs has shape
      (B, AT, A).
    num_trajectories: int, number of trajectories to collect.
    max_timestep: int or None. If it is a positive number, any trajectory that
      exceeds this many time-steps is cut and put in the completed bin, and the
      env is *not* reset.
    reset: bool, true if we want to reset the envs. The envs are also reset if
      max_timestep is None or <= 0.
    state: the state for `policy_fn`.
    rng: jax rng, splittable.
    temperature: float, temperature used in Gumbel sampling.
    boundary: int, pad the sequences to the multiples of this number.
    len_history_for_policy: int or None, the maximum history to keep for
      applying the policy on. If None, use the whole history.
    num_to_keep: int, while truncating trajectory how many time-steps to keep.
    abort_fn: callable, If not None, then at every step call and abort the
      trajectory collection if it returns True, if so reset the env and return
      None.
    raw_trajectory: bool, if True a list of trajectory.Trajectory objects is
      returned, otherwise a list of numpy representations of
      `trajectory.Trajectory` is returned.

  Returns:
    A tuple `(trajectories, num_done_trajectories, timing_info, state)`:
      trajectories: the first `num_trajectories` completed trajectories, as
        returned by `get_completed_trajectories_from_env`.
      num_done_trajectories: number of trajectories that ended because the env
        reported done, as opposed to being truncated at `max_timestep`.
      timing_info: dict of accumulated timings in milliseconds, rounded to two
        decimals.
      state: the policy state returned by the last `policy_fun` call.
    If `abort_fn` returns True, `(None, 0, {}, state)` is returned instead.
  """

  def gumbel_sample(log_probs):
    """Gumbel sampling."""
    u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    g = -np.log(-np.log(u))
    return np.argmax((log_probs / temperature) + g, axis=-1)

  # We need to reset all environments, if we're coming here the first time.
  if reset or max_timestep is None or max_timestep <= 0:
    env.reset()
  else:
    # Clear completed trajectories held internally.
    env.trajectories.clear_completed_trajectories()

  num_done_trajectories = 0

  policy_application_total_time = 0
  env_actions_total_time = 0
  bare_env_run_time = 0
  while env.trajectories.num_completed_trajectories < num_trajectories:
    # Check if we should abort and return nothing.
    if abort_fn and abort_fn():
      # We should also reset the environment, since it will have some
      # trajectories (complete and incomplete) that we want to discard.
      env.reset()
      return None, 0, {}, state

    # Get all the observations for all the active trajectories.
    # Shape is (B, RT) + OBS
    # Bucket on whatever length is needed.
    padded_observations, lengths = env.trajectories.observations_np(
        boundary=boundary,
        len_history_for_policy=len_history_for_policy)

    B = padded_observations.shape[0]  # pylint: disable=invalid-name

    assert B == env.batch_size
    assert (B,) == lengths.shape

    t1 = time.time()
    log_probs, value_preds, state, rng = policy_fun(
        padded_observations, lengths, state=state, rng=rng)
    policy_application_total_time += (time.time() - t1)

    assert B == log_probs.shape[0]

    actions = gumbel_sample(log_probs)
    if isinstance(env.action_space, gym.spaces.Discrete):
      actions = np.squeeze(actions, axis=1)

    # Step through the env.
    t1 = time.time()
    _, _, dones, env_infos = env.step(
        actions,
        infos={
            "log_prob_actions": log_probs,
            "value_predictions": value_preds,
        })
    env_actions_total_time += (time.time() - t1)
    # Not every env reports its bare run time; count missing entries as 0
    # instead of failing the whole collection with a KeyError.
    bare_env_run_time += sum(
        info.get("__bare_env_run_time__", 0) for info in env_infos)

    # Count the number of done trajectories, the others could just have been
    # truncated.
    num_done_trajectories += np.sum(dones)

    # Get the indices where we are done ...
    done_idxs = done_indices(dones)

    # ... and reset those.
    t1 = time.time()
    if done_idxs.size:
      env.reset(indices=done_idxs)
    env_actions_total_time += (time.time() - t1)

    if max_timestep is None or max_timestep < 1:
      continue

    # Are there any trajectories that have exceeded the time-limit we want.
    lengths = env.trajectories.trajectory_lengths
    exceeded_time_limit_idxs = done_indices(lengths > max_timestep)

    # If so, reset these as well.
    t1 = time.time()
    if exceeded_time_limit_idxs.size:
      # This just cuts the trajectory, doesn't reset the env, so it continues
      # from where it left off.
      env.truncate(indices=exceeded_time_limit_idxs, num_to_keep=num_to_keep)
    env_actions_total_time += (time.time() - t1)

  # We have the trajectories we need, collect them (raw or as numpy).
  completed_trajectories = get_completed_trajectories_from_env(
      env, num_trajectories, raw_trajectory=raw_trajectory)
  timing_info = {
      "trajectory_collection/policy_application": policy_application_total_time,
      "trajectory_collection/env_actions": env_actions_total_time,
      "trajectory_collection/env_actions/bare_env": bare_env_run_time,
  }
  # Seconds -> milliseconds.
  timing_info = {k: round(1000 * v, 2) for k, v in timing_info.items()}
  return completed_trajectories, num_done_trajectories, timing_info, state
def make_env(batch_size=1,
             env_problem_name="",
             resize=True,
             resize_dims=(105, 80),
             max_timestep="None",
             clip_rewards=True,
             parallelism=1,
             use_tpu=False,
             num_actions=None,
             rendered_env=True,
             **env_kwargs):
  """Creates the env.

  Args:
    batch_size: number of envs in the batch.
    env_problem_name: name of the base gym env.
    resize: if falsy, a plain `GymEnvProblem` is returned. Otherwise a
      `RenderedEnvProblem` is returned, whose envs are wrapped with
      `gym_utils.gym_env_wrapper`.
    resize_dims: passed to the wrapper as `rendered_env_resize_to`.
    max_timestep: any value `int()` accepts becomes the wrapper's
      `rl_env_max_episode_steps`. Anything else (e.g. the default string
      "None", or None) passes None. Only used when `resize` is truthy.
    clip_rewards: if True, use reward range (-1, 1) with discrete rewards.
      Otherwise use non-discrete rewards.
    parallelism: passed through to the env problem.
    use_tpu: if True, the wrapper's `output_dtype` is np.int32.
    num_actions: passed through to the wrapper.
    rendered_env: passed to both the wrapper and `RenderedEnvProblem`.
    **env_kwargs: extra kwargs for the env problem constructor.

  Returns:
    A `gym_env_problem.GymEnvProblem` or a
    `rendered_env_problem.RenderedEnvProblem`.
  """
  if clip_rewards:
    env_kwargs.update({"reward_range": (-1, 1), "discrete_rewards": True})
  else:
    env_kwargs.update({"discrete_rewards": False})

  # TODO(henrykm) - below someone linked "resize" with "abnormality"
  # Probably we need more nuanced concept of "abnormality"
  # decoupled from "resize". Currently the resize flag implies
  # that we switch from a generic env to a wrapped env.
  # Overall this file and gym_utils.py look like good candidates
  # for a refactor.

  # No resizing needed, so let's be on the normal EnvProblem.
  if not resize:  # None or False
    return gym_env_problem.GymEnvProblem(
        base_env_name=env_problem_name,
        batch_size=batch_size,
        parallelism=parallelism,
        **env_kwargs)

  # int("None") raises ValueError and int(None) raises TypeError. Both mean
  # "no limit"; any other exception is a real error and should propagate.
  try:
    max_timestep = int(max_timestep)
  except (TypeError, ValueError):
    max_timestep = None

  wrapper_fn = functools.partial(
      gym_utils.gym_env_wrapper, **{
          "rl_env_max_episode_steps": max_timestep,
          "maxskip_env": True,
          "rendered_env": rendered_env,
          "rendered_env_resize_to": resize_dims,
          "sticky_actions": False,
          "output_dtype": np.int32 if use_tpu else None,
          "num_actions": num_actions,
      })

  return rendered_env_problem.RenderedEnvProblem(
      base_env_name=env_problem_name,
      batch_size=batch_size,
      parallelism=parallelism,
      rendered_env=rendered_env,
      env_wrapper_fn=wrapper_fn,
      **env_kwargs)
================================================
FILE: tensor2tensor/envs/env_problem_utils_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for env_problem_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.envs import env_problem_utils
from tensor2tensor.envs import gym_env_problem
from tensor2tensor.envs import tic_tac_toe_env # pylint: disable=unused-import
from tensor2tensor.envs import tic_tac_toe_env_problem
import tensorflow.compat.v1 as tf
class EnvProblemUtilsTest(tf.test.TestCase):
  """Tests for the helpers in `env_problem_utils`."""

  def test_play_env_problem_randomly(self):
    batch_size = 5
    num_steps = 100

    ep = tic_tac_toe_env_problem.TicTacToeEnvProblem()
    ep.initialize(batch_size=batch_size)

    env_problem_utils.play_env_problem_randomly(ep, num_steps)

    # One time-step per env per step, one extra each time an env was done (its
    # reset observation), plus batch_size pending time-steps.
    expected_time_steps = (num_steps * batch_size +
                           len(ep.trajectories.completed_trajectories) +
                           batch_size)
    self.assertEqual(expected_time_steps, ep.trajectories.num_time_steps)

  def test_play_env_problem_with_policy(self):
    env = gym_env_problem.GymEnvProblem(
        base_env_name="CartPole-v0", batch_size=2, reward_range=(-1, 1))

    # At most this many observations should reach the policy function.
    len_history_for_policy = 4

    def policy_fun(observations, lengths, state=None, rng=None):
      del lengths
      batch = observations.shape[0]
      # Everything from time-step len_history_for_policy onwards is padding.
      self.assertTrue(
          np.all(observations[:, len_history_for_policy:, ...] == 0))
      self.assertFalse(
          np.all(observations[:, :len_history_for_policy, ...] == 0))
      n_actions = env.action_space.n
      probs = np.exp(np.random.uniform(size=(batch, 1, n_actions)))
      probs = probs / np.sum(probs, axis=-1, keepdims=True)
      return np.log(probs), np.mean(probs, axis=-1), state, rng

    max_timestep = 15
    num_trajectories = 2
    trajectories, _, _, _ = env_problem_utils.play_env_problem_with_policy(
        env,
        policy_fun,
        num_trajectories=num_trajectories,
        max_timestep=max_timestep,
        len_history_for_policy=len_history_for_policy)

    self.assertEqual(num_trajectories, len(trajectories))

    # Check the shapes inside every collected trajectory.
    for traj in trajectories:
      T = traj[1].shape[0]  # pylint: disable=invalid-name
      self.assertEqual((T + 1, 4), traj[0].shape)  # (4,) is OBS
      self.assertEqual((T,), traj[2].shape)
      self.assertEqual(T, len(traj[4]["log_prob_actions"]))
      self.assertEqual(T, len(traj[4]["value_predictions"]))
      self.assertLessEqual(T, max_timestep)
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/envs/gym_env_problem.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for envs that store their history.
EnvProblem subclasses Problem and also implements the Gym interface (step,
reset, render, close, seed).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import multiprocessing.pool
import time
from absl import logging
import gym
import numpy as np
from tensor2tensor.envs import env_problem
from tensor2tensor.envs import trajectory
# This is a compatibility shim introduced to support NumPy 1.24. See:
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
def _stack(xs):
  """Stacks `xs` along a new leading axis, falling back to an object array.

  If `np.stack` rejects the inputs with a `ValueError` (as can happen with
  ragged/inhomogeneous elements under NumPy >= 1.24, see NEP 34), the inputs
  are first wrapped in an explicit `dtype=object` array and stacked again.
  Any error from that second attempt propagates to the caller.

  Args:
    xs: A sequence of per-env values (observations, rewards, dones, infos...).

  Returns:
    np.ndarray whose first dimension is `len(xs)`.
  """
  try:
    stacked = np.stack(xs)
  except ValueError:
    as_objects = np.asarray(xs, dtype=object)
    stacked = np.stack(as_objects)
  return stacked
class GymEnvProblem(env_problem.EnvProblem):
  """An EnvProblem implemented as a batch of gym envs.

  This implementation works well when the env is not batched by default, e.g.
  any gym env. We create `batch_size` envs and store them in a list. Every
  function that interacts with the envs, like reset, step or close, goes over
  the env list. For example, when reset is called with specific indices, only
  those indices are reset.

  The usage of this class will look like the following:

  # 1. Creates and initializes the env_problem.
  ep = GymEnvProblem(...)

  # 2. One needs to call reset() at the start, this resets all envs.
  ep.reset()

  # 3. Call step with actions for all envs, i.e. len(action) = batch_size
  obs, rewards, dones, infos = ep.step(actions)

  # 4. Figure out which envs got done and reset only those.
  ep.reset(indices=env_problem_utils.done_indices(dones))

  # 5. Go back to Step #3 to further interact with the env or just dump the
  # generated data to disk by calling:
  ep.generate_data(...)

  # 6. If we now need to use this object again to play a few more iterations
  # perhaps with a different batch size or maybe not recording the data, then
  # we need to re-initialize environments and do some book-keeping, call:
  ep.initialize_environments(batch_size)

  # 7. Go back to Step #2, i.e. reset all envs.

  # 8. When done, call close() to close the envs and release worker threads.
  ep.close()

  NOTE: Look at `GymEnvProblemTest.test_interaction_with_env` and/or
  `GymEnvProblemTest.test_generate_data`.

  NOTE: We rely heavily on the underlying environments exposing a gym style
  interface, i.e. in addition to reset(), step() and close() we have access to
  the following properties: observation_space, action_space, reward_range.
  """

  def __init__(self,
               base_env_name=None,
               env_wrapper_fn=None,
               reward_range=None,
               **kwargs):
    """Initializes this class by creating the envs and managing trajectories.

    Args:
      base_env_name: (string) passed to `gym.make` to make the underlying
        environment.
      env_wrapper_fn: (callable(env): env) Applies gym wrappers to the base
        environment.
      reward_range: (tuple(number, number) or None) the first element is the
        minimum reward and the second is the maximum reward, used to clip and
        process the raw reward in `process_rewards`. If None, this is inferred
        from the inner environments.
      **kwargs: (dict) Arguments passed to the base class.
    """
    # Name for the base environment, will be used in `gym.make` in
    # the default implementation of `initialize_environments`.
    self._base_env_name = base_env_name

    # An env generates data when it is given actions by an agent which is either
    # a policy or a human -- this is supposed to be the `id` of the agent.
    #
    # In practice, this is used only to store (and possibly retrieve) history
    # to an appropriate directory.
    self._agent_id = "default"

    # We clip rewards to this range before processing them further, as described
    # in `process_rewards`.
    self._reward_range = reward_range

    # Initialize the environment(s).

    # This can either be a list of environments of len `batch_size` or this can
    # be a Neural Network, in which case it will be fed input with first
    # dimension = `batch_size`.
    self._envs = None

    # Worker thread pool used by `_reset`/`_step` when parallelism > 1. It is
    # (re)created by `initialize_environments` and released by `close`.
    self._pool = None
    self._parallelism = 1

    self._env_wrapper_fn = env_wrapper_fn

    # Call the super's ctor. It will use some of the member fields, so we call
    # it in the end.
    super(GymEnvProblem, self).__init__(**kwargs)

  @property
  def base_env_name(self):
    return self._base_env_name

  def _close_pool(self):
    """Shuts down the worker thread pool, if any, and forgets it."""
    pool = getattr(self, "_pool", None)
    if pool is None:
      return
    # `map` is synchronous, so no tasks are outstanding here; close + join lets
    # the worker threads exit cleanly instead of leaking.
    pool.close()
    pool.join()
    self._pool = None

  def _verify_same_spaces(self):
    """Verifies that all the envs have the same observation and action space.

    Raises:
      ValueError: if the envs are not initialized, or if any env's observation
        or action space differs from the first env's.
    """

    # Pre-conditions: self._envs is initialized.

    if self._envs is None:
      raise ValueError("Environments not initialized.")

    if not isinstance(self._envs, list):
      logging.warning("Not checking observation and action space "
                      "compatibility across envs, since there is just one.")
      return

    # NOTE: We compare string representations of observation_space and
    # action_space because compositional classes like space.Tuple don't return
    # true on object comparison.

    if not all(
        str(env.observation_space) == str(self.observation_space)
        for env in self._envs):
      err_str = ("All environments should have the same observation space, but "
                 "don't.")
      logging.error(err_str)
      # Log all observation spaces.
      for i, env in enumerate(self._envs):
        logging.error("Env[%d] has observation space [%s]", i,
                      env.observation_space)
      raise ValueError(err_str)

    if not all(
        str(env.action_space) == str(self.action_space) for env in self._envs):
      err_str = "All environments should have the same action space, but don't."
      logging.error(err_str)
      # Log all action spaces.
      for i, env in enumerate(self._envs):
        logging.error("Env[%d] has action space [%s]", i, env.action_space)
      raise ValueError(err_str)

  def initialize_environments(self,
                              batch_size=1,
                              parallelism=1,
                              per_env_kwargs=None,
                              **kwargs):
    """Initializes the environments.

    Calling this again (e.g. to change the batch size) replaces the envs and
    the worker thread pool; the previous pool is shut down first.

    Args:
      batch_size: (int) Number of `self.base_env_name` envs to initialize.
      parallelism: (int) If this is greater than one then we run the envs in
        parallel using multi-threading.
      per_env_kwargs: (list or None) An optional list of dictionaries to pass to
        gym.make. If not None, length should match `batch_size`.
      **kwargs: (dict) Kwargs to pass to gym.make.
    """
    assert batch_size >= 1
    if per_env_kwargs is not None:
      assert batch_size == len(per_env_kwargs)
    else:
      per_env_kwargs = [{} for _ in range(batch_size)]

    # By now `per_env_kwargs` is a list of dictionaries of size batch_size.
    # The individual dictionaries may be empty.

    def union_dicts(dict1, dict2):
      """Returns a new dict with `dict1` updated by `dict2`."""
      copy_dict1 = copy.copy(dict1)
      copy_dict1.update(dict2)
      return copy_dict1

    self._envs = [
        gym.make(self.base_env_name, **union_dicts(kwargs, env_kwarg))
        for env_kwarg in per_env_kwargs
    ]
    self._parallelism = parallelism
    # Release the pool from any previous initialization before replacing it,
    # otherwise its worker threads leak on every re-initialization.
    self._close_pool()
    self._pool = multiprocessing.pool.ThreadPool(self._parallelism)
    if self._env_wrapper_fn is not None:
      self._envs = list(map(self._env_wrapper_fn, self._envs))

    self._verify_same_spaces()

    # If self.reward_range is None, i.e. this means that we should take the
    # reward range of the env.
    if self.reward_range is None:
      self._reward_range = self._envs[0].reward_range

    # This data structure stores the history of each env.
    #
    # NOTE: Even if the env is a NN and can step in all batches concurrently, it
    # is still valuable to store the trajectories separately.
    self._trajectories = trajectory.BatchTrajectory(batch_size=batch_size)

  def assert_common_preconditions(self):
    # Asserts on the common pre-conditions of:
    #  - self._envs is initialized.
    #  - self._envs is a list.
    assert self._envs
    assert isinstance(self._envs, list)

  @property
  def observation_space(self):
    return self._envs[0].observation_space

  @property
  def action_space(self):
    return self._envs[0].action_space

  @property
  def reward_range(self):
    return self._reward_range

  def seed(self, seed=None):
    """Seeds every underlying env, then defers to the base class's `seed`."""
    if not self._envs:
      logging.info("`seed` called on non-existent envs, doing nothing.")
      return None

    if not isinstance(self._envs, list):
      logging.warning("`seed` called on non-list envs, doing nothing.")
      return None

    logging.warning(
        "Called `seed` on EnvProblem, calling seed on the underlying envs.")
    for env in self._envs:
      env.seed(seed)

    return super(GymEnvProblem, self).seed(seed=seed)

  def close(self):
    """Closes all underlying envs and shuts down the worker thread pool."""
    # Release the worker threads even when there are no envs to close.
    self._close_pool()

    if not self._envs:
      logging.info("`close` called on non-existent envs, doing nothing.")
      return

    if not isinstance(self._envs, list):
      logging.warning("`close` called on non-list envs, doing nothing.")
      return

    # Call close on all the envs one by one.
    for env in self._envs:
      env.close()

  def _reset(self, indices):
    """Resets environments at indices; doesn't pre-process or record.

    Args:
      indices: list of indices of underlying envs to call reset on.

    Returns:
      np.ndarray of stacked observations from the reset-ed envs.
    """

    # This returns a numpy array with first dimension `len(indices)` and the
    # rest being the dimensionality of the observation.

    num_envs_to_reset = len(indices)
    observations = [None] * num_envs_to_reset

    def reset_at(idx):
      observations[idx] = self._envs[indices[idx]].reset()

    if self._parallelism > 1:
      self._pool.map(reset_at, range(num_envs_to_reset))
    else:
      for i in range(num_envs_to_reset):
        reset_at(i)

    return _stack(observations)

  def _step(self, actions):
    """Takes a step in all environments; doesn't pre-process or record.

    Each env's info dict additionally gets a "__bare_env_run_time__" entry with
    the wall-clock seconds spent inside that env's `step`.

    Args:
      actions: (np.ndarray) with first dimension equal to the batch size.

    Returns:
      a tuple of stacked raw observations, raw rewards, dones and infos.
    """
    assert len(actions) == len(self._envs)

    observations = [None] * self.batch_size
    rewards = [None] * self.batch_size
    dones = [None] * self.batch_size
    infos = [{} for _ in range(self.batch_size)]

    def apply_step(i):
      t1 = time.time()
      observations[i], rewards[i], dones[i], infos[i] = self._envs[i].step(
          actions[i])
      t2 = time.time()
      infos[i]["__bare_env_run_time__"] = t2 - t1

    if self._parallelism > 1:
      self._pool.map(apply_step, range(self.batch_size))
    else:
      for i in range(self.batch_size):
        apply_step(i)

    # Convert each list (observations, rewards, ...) into np.array and return a
    # tuple.
    return tuple(map(_stack, [observations, rewards, dones, infos]))
================================================
FILE: tensor2tensor/envs/gym_env_problem_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.envs.gym_env_problem."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import gym
from gym.spaces import Box
from gym.spaces import Discrete
import numpy as np
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.envs import env_problem
from tensor2tensor.envs import env_problem_utils
from tensor2tensor.envs import gym_env_problem
from tensor2tensor.layers import modalities
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class GymEnvProblemTest(tf.test.TestCase):

  def setUp(self):
    self.tmp_dir = os.path.join(tf.test.get_temp_dir(), "tmp_dir")
    tf.gfile.MakeDirs(self.tmp_dir)

  def tearDown(self):
    tf.gfile.DeleteRecursively(self.tmp_dir)

  def test_setup(self):
    ep = gym_env_problem.GymEnvProblem(
        base_env_name="CartPole-v0", batch_size=5)
    # Checks that the environments were created (as a non-empty list).
    ep.assert_common_preconditions()

    # Expectations on the observation space.
    observation_space = ep.observation_space
    self.assertIsInstance(observation_space, Box)
    self.assertEqual(observation_space.shape, (4,))
    self.assertEqual(observation_space.dtype, np.float32)

    # Expectations on the action space.
    action_space = ep.action_space
    self.assertIsInstance(action_space, Discrete)
    self.assertEqual(action_space.shape, ())
    self.assertEqual(action_space.dtype, np.int64)
    self.assertEqual(ep.num_actions, 2)

    # Reward range is infinite here.
    self.assertFalse(ep.is_reward_range_finite)

  def test_reward_range(self):
    # Passing reward_range=None means take the reward range of the underlying
    # environment as the reward range.
    ep = gym_env_problem.GymEnvProblem(
        base_env_name="FrozenLake-v1", batch_size=5, reward_range=None)
    ep.assert_common_preconditions()

    # Assert reward range is finite here.
    self.assertTrue(ep.is_reward_range_finite)

    # Since reward_range=None was passed, the range should be the underlying
    # environment's own range.
    self.assertEqual(0, ep.reward_range[0])
    self.assertEqual(1, ep.reward_range[1])

  def test_default_processed_rewards_discrete(self):
    # This differs from the above in that it has a Tuple observation space.
    ep = gym_env_problem.GymEnvProblem(
        base_env_name="KellyCoinflip-v0", batch_size=5, reward_range=None)
    ep.assert_common_preconditions()

    # Assert reward range is finite here.
    self.assertTrue(ep.is_reward_range_finite)

    # Assert that it is as expected of the underlying environment.
    reward_range = ep.reward_range
    self.assertEqual(0, reward_range[0])

    # Google's version of Gym has maxWealth, vs max_wealth externally.
    max_wealth = getattr(ep._envs[0], "maxWealth",
                         getattr(ep._envs[0], "max_wealth", None))
    self.assertIsNotNone(max_wealth)
    self.assertEqual(max_wealth, reward_range[1])

    # Check that the processed rewards are discrete.
    self.assertTrue(ep.is_processed_rewards_discrete)

    # Assert on the number of rewards.
    self.assertEqual(ep.num_rewards, reward_range[1] - reward_range[0] + 1)

  def test_interaction_with_env(self):
    batch_size = 5
    reward_range = (-1, 1)
    ep = gym_env_problem.GymEnvProblem(
        base_env_name="KellyCoinflip-v0",
        batch_size=batch_size,
        reward_range=reward_range)

    # Resets all environments.
    ep.reset()

    # Let's play a few steps.
    nsteps = 100
    num_trajectories_completed = 0
    num_timesteps_completed = 0
    # If batch_done_at_step[i] = j then it means that i^th env last got done at
    # step = j.
    batch_done_at_step = np.full(batch_size, -1)
    for i in range(nsteps):
      # Sample batch_size actions from the action space and stack them (since
      # that is the expected type).
      actions = np.stack([ep.action_space.sample() for _ in range(batch_size)])

      _, _, dones, _ = ep.step(actions)

      # Do the book-keeping on number of trajectories completed and expect that
      # it matches ep's completed number.

      num_done = sum(dones)
      num_trajectories_completed += num_done

      self.assertEqual(num_trajectories_completed,
                       len(ep.trajectories.completed_trajectories))

      # Get the indices where we are done ...
      done_indices = env_problem_utils.done_indices(dones)

      # ... and reset those.
      ep.reset(indices=done_indices)

      # If nothing got done, go on to the next step.
      if done_indices.size == 0:
        # i.e. this is an empty array.
        continue

      # See when these indices were last done and calculate how many time-steps
      # each one took to get done.
      num_timesteps_completed += sum(i + 1 - batch_done_at_step[done_indices])
      batch_done_at_step[done_indices] = i

      # This should also match the number of time-steps completed given by ep.
      num_timesteps_completed_ep = sum(
          ct.num_time_steps for ct in ep.trajectories.completed_trajectories)
      self.assertEqual(num_timesteps_completed, num_timesteps_completed_ep)

    # Reset the trajectories.
    ep.trajectories.reset_batch_trajectories()
    self.assertEqual(0, len(ep.trajectories.completed_trajectories))

  def read_tfrecord_dataset(self, filenames, ep):
    """Counts trajectories and time-steps in the TFRecord files `filenames`.

    Records are iterated with `generator_utils.tfrecord_iterator`; a new
    trajectory is detected whenever the time-step field does not continue the
    previous record's time-step.

    Returns:
      tuple of (number of trajectories, number of time-steps/records).
    """
    last_timestep = -1
    num_time_steps = 0
    num_trajectories = 0
    for ex in generator_utils.tfrecord_iterator(
        filenames, example_spec=ep.example_reading_spec()[0]):
      num_time_steps += 1
      this_timestep = ex[env_problem.TIMESTEP_FIELD][0]
      if 1 + last_timestep != this_timestep:
        num_trajectories += 1
        self.assertEqual(0, this_timestep)
      last_timestep = this_timestep
    # Account for the final trajectory, which no later record terminates.
    num_trajectories += 1

    return num_trajectories, num_time_steps

  def play_env(self,
               env=None,
               nsteps=100,
               base_env_name=None,
               batch_size=5,
               reward_range=None):
    """Creates `GymEnvProblem` with the given arguments and plays it randomly.

    Args:
      env: optional env. If given, it is played as-is: it is not renamed, and
        one action is sampled per env in its own batch.
      nsteps: plays the env randomly for nsteps.
      base_env_name: passed to GymEnvProblem's init.
      batch_size: passed to GymEnvProblem's init.
      reward_range: passed to GymEnvProblem's init.

    Returns:
      tuple of gym_env_problem, number of trajectories done,
      number of trajectories done in the last step.
    """

    if env is None:
      env = gym_env_problem.GymEnvProblem(
          base_env_name=base_env_name,
          batch_size=batch_size,
          reward_range=reward_range)
      # Usually done by a registered subclass, we do this manually in the test.
      env.name = base_env_name

    # Sample one action per env of the problem actually being played; this can
    # differ from `batch_size` when a pre-built `env` is passed in.
    num_envs = env.batch_size

    # Reset all environments.
    env.reset()

    # Play for some steps to generate data.
    num_dones = 0
    num_dones_in_last_step = 0
    for _ in range(nsteps):
      # Sample actions.
      actions = np.stack([env.action_space.sample() for _ in range(num_envs)])
      # Step through it.
      _, _, dones, _ = env.step(actions)
      # Get the indices where we are done ...
      done_indices = env_problem_utils.done_indices(dones)
      # ... and reset those.
      env.reset(indices=done_indices)
      # count the number of dones we got, in this step and overall.
      num_dones_in_last_step = sum(dones)
      num_dones += num_dones_in_last_step

    return env, num_dones, num_dones_in_last_step

  def test_generate_data(self):
    base_env_name = "CartPole-v0"
    batch_size = 5
    reward_range = (-1, 1)
    nsteps = 100
    ep, num_dones, num_dones_in_last_step = self.play_env(
        base_env_name=base_env_name,
        batch_size=batch_size,
        reward_range=reward_range,
        nsteps=nsteps)

    # This is because every num_dones starts a new trajectory, and a further
    # batch_size are active at the last step when we call generate_data, but
    # the ones that got done in the last step (these have only one time-step in
    # their trajectory) will be skipped.
    expected_num_trajectories = num_dones + batch_size - num_dones_in_last_step

    # Similar logic as above, nsteps * batch_size overall `step` calls are made.
    expected_num_time_steps = (
        nsteps * batch_size) + num_dones + batch_size - num_dones_in_last_step

    # Dump the completed data to disk.
    ep.generate_data(self.tmp_dir, self.tmp_dir)

    # Read the written files and assert on the number of time steps.
    training_filenames = ep.training_filepaths(
        self.tmp_dir, ep.num_shards[problem.DatasetSplit.TRAIN], True)
    dev_filenames = ep.dev_filepaths(
        self.tmp_dir, ep.num_shards[problem.DatasetSplit.EVAL], True)

    training_trajectories, training_timesteps = self.read_tfrecord_dataset(
        training_filenames, ep)
    dev_trajectories, dev_timesteps = self.read_tfrecord_dataset(
        dev_filenames, ep)

    # This tests what we wrote on disk matches with what we computed.
    self.assertEqual(expected_num_time_steps,
                     training_timesteps + dev_timesteps)
    self.assertEqual(expected_num_trajectories,
                     training_trajectories + dev_trajectories)

  def test_problem_dataset_works(self):

    # We need to derive this class to set the required methods.
    class TestEnv(gym_env_problem.GymEnvProblem):
      name = "TestEnv"

      @property
      def input_modality(self):
        return modalities.ModalityType.REAL_L2_LOSS

      @property
      def input_vocab_size(self):
        return None

      @property
      def target_modality(self):
        return modalities.ModalityType.SYMBOL_WEIGHTS_ALL

      @property
      def target_vocab_size(self):
        return 2

      @property
      def action_modality(self):
        return modalities.ModalityType.SYMBOL_WEIGHTS_ALL

      @property
      def reward_modality(self):
        return modalities.ModalityType.SYMBOL_WEIGHTS_ALL

    base_env_name = "CartPole-v0"
    batch_size = 5
    reward_range = (-1, 1)

    env = TestEnv(
        base_env_name=base_env_name,
        batch_size=batch_size,
        reward_range=reward_range)

    nsteps = 100
    ep, _, _ = self.play_env(env=env, nsteps=nsteps)

    # Dump the completed data to disk.
    ep.generate_data(self.tmp_dir, self.tmp_dir)

    # Read the actual files and count the trajectories and time-steps.
    dev_filenames = ep.dev_filepaths(
        self.tmp_dir, ep.num_shards[problem.DatasetSplit.EVAL], True)
    dev_trajectories, dev_timesteps = self.read_tfrecord_dataset(
        dev_filenames, ep)

    # Count them using a tf.data.Dataset.
    dev_dataset = ep.dataset(tf_estimator.ModeKeys.EVAL, data_dir=self.tmp_dir)

    last_timestep = -1
    dev_timesteps_ds = 0
    dev_trajectories_ds = 0
    iterator = dev_dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as session:
      while True:
        try:
          tf_example_dict = session.run(next_element)

          # We have a time-step.
          dev_timesteps_ds += 1

          this_timestep = tf_example_dict[env_problem.TIMESTEP_FIELD][
              0]  # [0] since every value in tf_example_dict is an array/list.
          if 1 + last_timestep != this_timestep:
            dev_trajectories_ds += 1
            self.assertEqual(0, this_timestep)
          last_timestep = this_timestep
        except tf.errors.OutOfRangeError:
          # Account for the final trajectory, as in read_tfrecord_dataset.
          dev_trajectories_ds += 1
          break

    # Make sure that they agree.
    self.assertEqual(dev_trajectories, dev_trajectories_ds)
    self.assertEqual(dev_timesteps, dev_timesteps_ds)

  def test_resets_properly(self):
    base_env_name = "CartPole-v0"
    batch_size = 5
    reward_range = (-1, 1)
    nsteps = 100

    env = gym_env_problem.GymEnvProblem(
        base_env_name=base_env_name,
        batch_size=batch_size,
        reward_range=reward_range)
    env.name = base_env_name

    num_dones = 0
    while num_dones == 0:
      env, num_dones, _ = self.play_env(env=env,
                                        nsteps=nsteps,
                                        batch_size=batch_size,
                                        reward_range=reward_range)

    # Some completed trajectories have been generated.
    self.assertGreater(env.trajectories.num_completed_trajectories, 0)

    # This should clear the env completely of any state.
    env.reset()

    # Assert that there aren't any completed trajectories in the env now.
    self.assertEqual(env.trajectories.num_completed_trajectories, 0)

  def test_per_env_kwargs(self):

    # Creating a dummy class where we specify the action at which the env
    # returns done.
    class TestPerEnvKwargsEnv(gym.Env):
      """Test environment with the `done action` specified."""

      action_space = Discrete(3)
      observation_space = Box(low=-1.0, high=1.0, shape=())

      def __init__(self, done_action=0):
        self._done_action = done_action

      def _generate_ob(self):
        return self.observation_space.sample()

      def step(self, action):
        done = self._done_action == action
        reward = 1 if done else 0
        return (self._generate_ob(), reward, done, {})

      def reset(self):
        return self._generate_ob()

    # Registering it with gym.
    test_env_name = "TestPerEnvKwargsEnv-v0"
    gym.envs.register(id=test_env_name, entry_point=TestPerEnvKwargsEnv)

    # Creating a batch of those with different done actions.
    base_env_name = test_env_name
    batch_size = 2
    reward_range = (-1, 1)
    per_env_kwargs = [{"done_action": 1}, {"done_action": 2}]

    env = gym_env_problem.GymEnvProblem(
        base_env_name=base_env_name,
        batch_size=batch_size,
        reward_range=reward_range,
        per_env_kwargs=per_env_kwargs)

    _ = env.reset()

    # Finally querying the done actions.

    _, _, d, _ = env.step(np.array([0, 0]))
    self.assertFalse(d[0])
    self.assertFalse(d[1])

    _, _, d, _ = env.step(np.array([1, 2]))
    self.assertTrue(d[0])
    self.assertTrue(d[1])
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/envs/gym_spaces_utils.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Few utility functions to deal with gym spaces.
gym.spaces.Box and gym.spaces.Discrete are easiest to support.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Box
from gym.spaces import Discrete
import numpy as np
import tensorflow.compat.v1 as tf
def box_space_spec(box_space, tf_dtype):
  """Returns a `tf.FixedLenFeature` with the same shape as `box_space`."""
  feature_shape = box_space.shape
  return tf.FixedLenFeature(feature_shape, tf_dtype)
def discrete_space_spec(discrete_space, tf_dtype):
  """Returns a one-element `tf.FixedLenFeature` for a discrete space."""
  del discrete_space  # Only the dtype is needed; the shape is always (1,).
  single_value_shape = (1,)
  return tf.FixedLenFeature(single_value_shape, tf_dtype)
def gym_space_spec(gym_space):
  """Returns a reading spec of a gym space.

  NOTE: Only implemented currently for Box and Discrete.

  Args:
    gym_space: instance of gym.spaces whose spec we want.

  Returns:
    Reading spec for that space.

  Raises:
    TypeError: If the space's dtype cannot be converted to a `tf.DType`.
    NotImplementedError: For spaces whose reading spec we haven't implemented.
  """

  # First try to determine the type.
  try:
    tf_dtype = tf.as_dtype(gym_space.dtype)
  except TypeError:
    tf.logging.error("Cannot convert space's type [%s] to tf.dtype",
                     gym_space.dtype)
    # Bare `raise` re-raises the original exception with its traceback intact.
    raise

  # Now hand it over to the specialized functions.
  if isinstance(gym_space, Box):
    return box_space_spec(gym_space, tf_dtype)
  if isinstance(gym_space, Discrete):
    return discrete_space_spec(gym_space, tf_dtype)
  raise NotImplementedError(
      "Reading spec is not implemented for gym space of type %s." %
      type(gym_space).__name__)
def gym_space_encode(gym_space, observation):
  """Encodes `observation` so that `generator_utils.to_example` can consume it.

  Args:
    gym_space: The gym space `observation` comes from (Discrete or Box).
    observation: A value from `gym_space`.

  Returns:
    A list: `[observation]` for a Discrete space, or the flattened values of
    the observation for a Box space.

  Raises:
    NotImplementedError: For any other kind of space.
  """
  if isinstance(gym_space, Discrete):
    return [observation]
  if isinstance(gym_space, Box):
    # `np.asarray` lets plain (nested) lists work as well as ndarrays.
    return np.asarray(observation).reshape(-1).tolist()
  raise NotImplementedError(
      "Encoding is not implemented for gym space of type %s." %
      type(gym_space).__name__)
def cardinality(gym_space):
  """Number of elements that can be represented by the space.

  Makes the most sense for Discrete or Box type with integral dtype, ex: number
  of actions in an action space.

  Args:
    gym_space: The gym space.

  Returns:
    np.int64 number of observations that can be represented by this space, or
    returns None when this doesn't make sense, i.e. float boxes etc.

  Raises:
    NotImplementedError when a space's cardinality makes sense but we haven't
    implemented it.
  """

  dtype = gym_space.dtype
  # Cover every floating dtype (float16/32/64...), not just float32/float64.
  # Guard against `None`: np.issubdtype(None, ...) would treat it as float64.
  if dtype is not None and np.issubdtype(dtype, np.floating):
    tf.logging.warning("Returning None for a float gym space's cardinality: %s",
                       gym_space)
    return None

  if isinstance(gym_space, Discrete):
    return gym_space.n

  if isinstance(gym_space, Box):
    # Construct a box with all possible values in this box and take a product.
    # Widen the bounds to int64 first: in the space's own dtype the arithmetic
    # can wrap around, e.g. for uint8 with high=255, 255 - 0 + 1 == 0.
    # NOTE(review): the product itself can still overflow int64 for very large
    # boxes (e.g. image spaces).
    high = np.asarray(gym_space.high, dtype=np.int64)
    low = np.asarray(gym_space.low, dtype=np.int64)
    return np.prod(high - low + 1)

  raise NotImplementedError(
      "Cardinality is not implemented for gym space of type %s." %
      type(gym_space).__name__)
================================================
FILE: tensor2tensor/envs/gym_spaces_utils_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for gym_spaces_utils.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Box
from gym.spaces import Discrete
import numpy as np
from tensor2tensor.envs import gym_spaces_utils
import tensorflow.compat.v1 as tf
class GymSpacesUtilsTest(tf.test.TestCase):
  """Checks the reading specs and encodings for Box and Discrete spaces."""

  def test_discrete_space_spec(self):
    spec = gym_spaces_utils.gym_space_spec(Discrete(100))
    self.assertIsInstance(spec, tf.FixedLenFeature)
    self.assertEqual(spec.dtype, tf.int64)
    self.assertListEqual(list(spec.shape), [1])

  def test_box_space_spec(self):
    space = Box(low=0, high=10, shape=[5, 6], dtype=np.float32)
    spec = gym_spaces_utils.gym_space_spec(space)
    self.assertIsInstance(spec, tf.FixedLenFeature)
    self.assertEqual(spec.dtype, tf.float32)
    self.assertListEqual(list(spec.shape), [5, 6])

  def test_discrete_space_encode(self):
    space = Discrete(100)
    sampled = space.sample()
    encoded = gym_spaces_utils.gym_space_encode(space, sampled)
    self.assertListEqual([sampled], encoded)

  def test_box_space_encode(self):
    space = Box(low=0, high=10, shape=[2], dtype=np.int64)
    encoded = gym_spaces_utils.gym_space_encode(space, np.array([2, 3]))
    self.assertListEqual([2, 3], encoded)
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/envs/mujoco_problems.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mujoco Gym environments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensor2tensor.envs import rendered_env_problem
from tensor2tensor.layers import modalities
from tensor2tensor.rl import gym_utils
from tensor2tensor.utils import registry
@registry.register_env_problem
class ReacherEnvProblem(rendered_env_problem.RenderedEnvProblem):
  """Mujoco's reacher environment, wrapped so that it is rendered."""

  def __init__(self):
    # Options for `gym_utils.gym_env_wrapper`.
    wrapper_options = {
        "rl_env_max_episode_steps": -1,
        "maxskip_env": False,
        "rendered_env": True,
        "rendered_env_resize_to": None,  # Do not resize frames
        "sticky_actions": False,
        "output_dtype": None,
        "num_actions": None,
    }
    wrap = functools.partial(gym_utils.gym_env_wrapper, **wrapper_options)
    super(ReacherEnvProblem, self).__init__(
        base_env_name="Reacher-v2", env_wrapper_fn=wrap)

  # Inputs and targets are both rendered video frames.

  @property
  def input_modality(self):
    return modalities.ModalityType.VIDEO

  @property
  def target_modality(self):
    return modalities.ModalityType.VIDEO

  # Actions and rewards are passed through unchanged.

  @property
  def action_modality(self):
    return modalities.ModalityType.IDENTITY

  @property
  def reward_modality(self):
    return modalities.ModalityType.IDENTITY

  @property
  def input_vocab_size(self):
    return 256

  @property
  def target_vocab_size(self):
    return 256
================================================
FILE: tensor2tensor/envs/mujoco_problems_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.envs.mujoco_problems."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.envs import env_problem_utils
from tensor2tensor.envs import mujoco_problems # pylint: disable=unused-import
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
class ReacherEnvProblemTest(tf.test.TestCase):

  def test_registration_and_interaction_with_env_problem(self):
    batch_size = 5
    # Looking the problem up by name checks that registration happened.
    env = registry.env_problem("reacher_env_problem", batch_size=batch_size)
    env.reset()
    total_done = 0
    for _ in range(100):
      actions = np.stack(
          [env.action_space.sample() for _ in range(batch_size)])
      obs, rewards, dones, infos = env.step(actions)
      # Every returned quantity must be batched.
      for batched in (obs, rewards, dones, infos):
        self.assertEqual(batch_size, len(batched))
      env.reset(env_problem_utils.done_indices(dones))
      total_done += sum(dones)
    # At least one episode must have finished.
    self.assertGreater(total_done, 0)


if __name__ == "__main__":
  tf.test.main()
================================================
FILE: tensor2tensor/envs/rendered_env_problem.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for env problems with RGB array as observation space."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import png
import six
from tensor2tensor.data_generators import video_utils
from tensor2tensor.envs import env_problem
from tensor2tensor.envs import gym_env_problem
from tensor2tensor.utils import contrib
import tensorflow.compat.v1 as tf
# Keys under which `_generate_time_steps` stores each rendered frame (encoded
# bytes, format, height, width) in the generated time-step dicts.
_IMAGE_ENCODED_FIELD = "image/encoded"
_IMAGE_FORMAT_FIELD = "image/format"
_IMAGE_HEIGHT_FIELD = "image/height"
_IMAGE_WIDTH_FIELD = "image/width"
# Key holding the frame's index; copied from the env's time-step field.
_FRAME_NUMBER_FIELD = "frame_number"
# Image format written under `_IMAGE_FORMAT_FIELD`; frames are PNG-encoded.
_FORMAT = "png"
class RenderedEnvProblem(gym_env_problem.GymEnvProblem,
                         video_utils.VideoProblem):
  """An `EnvProblem` whose observations are rendered image arrays.

  This takes care of wrapping a rendered gym environment to behave like a
  `VideoProblem`. This class assumes the underlying gym environment is either a
  `gym_utils.RenderedEnv` or it natively returns rendered scene for
  observations, i.e. the underlying gym environment should have a `Box`
  observation space with the shape [frame_height, frame_width, channels].
  Frames with 1 (grayscale), 2 (grayscale + alpha), 3 (RGB) or 4 (RGBA)
  channels can be encoded.

  Note: `GymEnvProblem` precedes `VideoProblem` in the bases, so where both
  define a method, `GymEnvProblem`'s version is found first. Methods below that
  need a specific parent's behaviour call that parent explicitly.
  """

  # pypng colour mode to use for each supported number of channels.
  _PNG_MODES = {1: "L", 2: "LA", 3: "RGB", 4: "RGBA"}

  def __init__(self, *args, **kwargs):
    """Initialize by calling both parents' constructors explicitly."""
    gym_env_problem.GymEnvProblem.__init__(self, *args, **kwargs)
    video_utils.VideoProblem.__init__(self)

  def initialize_environments(self,
                              batch_size=1,
                              parallelism=1,
                              rendered_env=True,
                              per_env_kwargs=None,
                              **kwargs):
    """Initializes the environments and checks the observation space rank.

    Args:
      batch_size: forwarded to `GymEnvProblem.initialize_environments`.
      parallelism: forwarded to `GymEnvProblem.initialize_environments`.
      rendered_env: if True, assert that observations are 3-D, i.e.
        [height, width, channels]. Not forwarded to the parent.
      per_env_kwargs: forwarded to `GymEnvProblem.initialize_environments`.
      **kwargs: forwarded to `GymEnvProblem.initialize_environments`.
    """
    gym_env_problem.GymEnvProblem.initialize_environments(
        self, batch_size=batch_size, parallelism=parallelism,
        per_env_kwargs=per_env_kwargs, **kwargs)
    # Assert the underlying gym environment has correct observation space.
    if rendered_env:
      assert len(self.observation_spec.shape) == 3, (
          "Expected observations of shape [height, width, channels], got "
          "shape {}".format(self.observation_spec.shape))

  def example_reading_spec(self):
    """Return a mix of env and video data fields and decoders."""
    slim = contrib.slim()
    video_fields, video_decoders = (
        video_utils.VideoProblem.example_reading_spec(self))
    env_fields, env_decoders = (
        gym_env_problem.GymEnvProblem.example_reading_spec(self))

    # Remove raw observations field since we want to capture them as videos.
    env_fields.pop(env_problem.OBSERVATION_FIELD)
    env_decoders.pop(env_problem.OBSERVATION_FIELD)

    # Add frame number spec and decoder.
    env_fields[_FRAME_NUMBER_FIELD] = tf.FixedLenFeature((1,), tf.int64)
    env_decoders[_FRAME_NUMBER_FIELD] = slim.tfexample_decoder.Tensor(
        _FRAME_NUMBER_FIELD)

    # Add video fields and decoders.
    env_fields.update(video_fields)
    env_decoders.update(video_decoders)
    return env_fields, env_decoders

  def _generate_time_steps(self, trajectory_list):
    """Transforms time step observations to PNG-encoded video frames.

    Args:
      trajectory_list: forwarded to `GymEnvProblem._generate_time_steps`.

    Yields:
      The parent's time-step dicts with the raw observation replaced by the
      encoded frame, its format, height and width, plus a frame number.

    Raises:
      ValueError: if frames have a channel count PNG cannot represent.
    """
    for time_step in gym_env_problem.GymEnvProblem._generate_time_steps(
        self, trajectory_list):
      # Convert the rendered observations from numpy to png format.
      frame_np = np.array(time_step.pop(env_problem.OBSERVATION_FIELD))
      frame_np = frame_np.reshape(
          [self.frame_height, self.frame_width, self.num_channels])
      png_mode = self._PNG_MODES.get(self.num_channels)
      if png_mode is None:
        raise ValueError(
            "Cannot PNG-encode frames with {} channels; expected one of "
            "{}.".format(self.num_channels, sorted(self._PNG_MODES)))
      frame = png.from_array(frame_np, png_mode, info={"bitdepth": 8})
      frame_buffer = six.BytesIO()
      frame.save(frame_buffer)

      # Put the encoded frame back.
      time_step[_IMAGE_ENCODED_FIELD] = [frame_buffer.getvalue()]
      time_step[_IMAGE_FORMAT_FIELD] = [_FORMAT]
      time_step[_IMAGE_HEIGHT_FIELD] = [self.frame_height]
      time_step[_IMAGE_WIDTH_FIELD] = [self.frame_width]

      # Add the frame number.
      time_step[_FRAME_NUMBER_FIELD] = time_step[env_problem.TIMESTEP_FIELD]
      yield time_step

  @property
  def num_channels(self):
    """Channels per frame: the last dimension of the observation shape."""
    return self.observation_spec.shape[2]

  @property
  def frame_height(self):
    """Frame height: the first dimension of the observation shape."""
    return self.observation_spec.shape[0]

  @property
  def frame_width(self):
    """Frame width: the second dimension of the observation shape."""
    return self.observation_spec.shape[1]

  @property
  def total_number_of_frames(self):
    """Upper bound on the total number of frames across all environments.

    This is used to decide sharding. See `VideoProblem.total_number_of_frames`
    for more details.

    Returns:
      number of frames among all examples in the dataset.
    """
    return self.trajectories.num_time_steps
================================================
FILE: tensor2tensor/envs/rendered_env_problem_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.envs.rendered_env_problem."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.envs import env_problem
from tensor2tensor.envs import env_problem_utils
from tensor2tensor.envs import rendered_env_problem
from tensor2tensor.envs.mujoco_problems import ReacherEnvProblem
import tensorflow.compat.v1 as tf
class RenderedEnvProblemTest(tf.test.TestCase):

  def test_generate_timesteps(self):
    num_steps = 5
    env = ReacherEnvProblem()
    env.initialize(batch_size=2)
    env_problem_utils.play_env_problem_randomly(env, num_steps=num_steps)
    env.trajectories.complete_all_trajectories()
    # Each trajectory is expected to hold the initial observation plus one
    # time-step per played step (assuming no episode ends within `num_steps`),
    # so frame numbers cycle through 0..num_steps across trajectories.
    frames_per_trajectory = num_steps + 1
    frame_number = 0
    for time_step in env._generate_time_steps(
        env.trajectories.completed_trajectories):
      # original observation should not be in time_step
      self.assertNotIn(env_problem.OBSERVATION_FIELD, time_step)
      # validate frame
      self.assertIn(rendered_env_problem._IMAGE_ENCODED_FIELD, time_step)
      self.assertIn(rendered_env_problem._IMAGE_HEIGHT_FIELD, time_step)
      self.assertIn(rendered_env_problem._IMAGE_WIDTH_FIELD, time_step)
      self.assertIn(rendered_env_problem._IMAGE_FORMAT_FIELD, time_step)
      self.assertIn(rendered_env_problem._FRAME_NUMBER_FIELD, time_step)

      decoded_frame = tf.image.decode_png(
          time_step[rendered_env_problem._IMAGE_ENCODED_FIELD][0])
      decoded_frame = self.evaluate(decoded_frame)
      self.assertListEqual(
          [env.frame_height, env.frame_width, env.num_channels],
          list(decoded_frame.shape))
      self.assertListEqual([rendered_env_problem._FORMAT],
                           time_step[rendered_env_problem._IMAGE_FORMAT_FIELD])
      self.assertListEqual([frame_number],
                           time_step[rendered_env_problem._FRAME_NUMBER_FIELD])
      self.assertListEqual([env.frame_width],
                           time_step[rendered_env_problem._IMAGE_WIDTH_FIELD])
      self.assertListEqual([env.frame_height],
                           time_step[rendered_env_problem._IMAGE_HEIGHT_FIELD])
      frame_number = (frame_number + 1) % frames_per_trajectory


if __name__ == "__main__":
  tf.test.main()
================================================
FILE: tensor2tensor/envs/tic_tac_toe_env.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gym Tic-Tac-Toe environment.
The environment acts as the agent's opponent; on every reset it randomly
decides (with a 50% chance) whether it makes the first move. The environment
follows a random policy for now.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from tensor2tensor.data_generators import problem
from tensor2tensor.layers import modalities
from tensor2tensor.rl import gym_utils
def encode_pos(i, j):
  """Maps board coordinates (row i, column j) to one row-major scalar index."""
  row_offset = i * 3
  return row_offset + j
def decode_pos(pos):
  """Decodes a row-major scalar board position back into a (row, col) pair."""
  return divmod(pos, 3)
def get_open_spaces(board):
  """Lists the encoded positions of all empty (0) cells in row-major order."""
  return [encode_pos(row, col)
          for row in range(3)
          for col in range(3)
          if board[row][col] == 0]
def get_reward_and_done(board):
  """Scores a board from the agent's (+1 player's) point of view.

  Args:
    board: 3x3 numpy array with 1 for the agent, -1 for the env, 0 for empty.

  Returns:
    (reward, done): reward is -1 if the env has a full line, +1 if the agent
    does, otherwise 0; done is True on a win/loss or when no cell is open.
  """
  # Line sums over the three rows, the three columns and both diagonals.
  line_sums = [np.sum(board[r, :]) for r in range(3)]
  line_sums += [np.sum(board[:, c]) for c in range(3)]
  line_sums.append(np.sum([board[k, k] for k in range(3)]))
  line_sums.append(np.sum([board[k, 2 - k] for k in range(3)]))

  # The env's win is checked first.
  if -3 in line_sums:
    return -1, True
  if 3 in line_sums:
    return 1, True
  # No winner: the game is over (a draw) only if the board is full.
  return 0, not get_open_spaces(board)
# TODO(afrozm): This should eventually subclass Problem.
class TicTacToeEnv(gym.Env):
  """Simple TicTacToe Env, starts the game randomly half of the time.

  The board is a 3x3 int64 array: 1 marks a cell played by the agent, -1 a cell
  played by the environment and 0 an empty cell. On reset the environment moves
  first with a 50% chance, and it answers every agent move with a random move.
  """

  def __init__(self, strict=False):
    """Creates the env and starts a first game.

    Args:
      strict: if True, `step` asserts that the game is not over and that the
        action targets an open cell. If False, an illegal move is a no-op.
    """
    self.strict = strict

    # What about metadata and spec?
    self.reward_range = (-1.0, 1.0)

    # Action space -- 9 positions that we can choose to mark.
    self.action_space = spaces.Discrete(9)
    # Observation space -- the 3x3 board with values in {-1, 0, 1}.
    self.observation_space = spaces.Box(
        low=-1, high=1, shape=(3, 3), dtype=np.int64)

    # Set the seed.
    self.np_random = None
    self.seed()

    # Start the game.
    self.board_state = None
    self.done = False
    self.reset()

  def seed(self, seed=None):
    """Seeds the env's random generator; returns the used seed in a list."""
    self.np_random, seed = seeding.np_random(seed)
    return [seed]

  # TODO(afrozm): Parametrize by some policy so that the env plays in an optimal
  # way.
  def play_random_move(self):
    """Marks a uniformly random open cell with -1 (the env's move).

    Returns:
      True if a move was played, False if the board had no open cell.
    """
    # Select open spaces.
    open_spaces = get_open_spaces(self.board_state)
    if not open_spaces:
      return False
    # Choose a space and mark it.
    pos = self.np_random.choice(open_spaces)
    i, j = decode_pos(pos)
    self.board_state[i, j] = -1
    return True

  def reset(self):
    """Starts a new game and returns (a copy of) the initial board."""
    self.board_state = np.zeros((3, 3), dtype=np.int64)
    # A new game is not over. Without this, a strict env would reject every
    # `step` after the first finished game.
    self.done = False
    # We'll start with a 50% chance.
    if self.np_random.choice([0, 1]) == 0:
      self.play_random_move()
    # Return a copy so the caller's observation isn't mutated by later moves.
    return self.board_state.copy()

  def render(self, mode="human"):
    """Returns the board as text: 'x' = env, 'o' = agent, '-' = empty."""
    # Unused.
    del mode
    board_str = ""
    for i in range(3):
      for j in range(3):
        pos = self.board_state[i, j]
        if pos == -1:
          board_str += "x"
        elif pos == 0:
          board_str += "-"
        else:
          board_str += "o"
      board_str += "\n"
    return board_str

  def step(self, action):
    """Plays `action` for the agent, then a random reply for the env.

    Args:
      action: int in [0, 9), the encoded cell to mark (see `encode_pos`).

    Returns:
      (observation, reward, done, info). The observation is a copy of the
      board; reward is from the agent's point of view; info is always {}.
    """
    # Are we already done?
    if self.strict:
      assert not self.done

    # Action has to belong to the action state.
    assert self.action_space.contains(action)

    # Is it a legitimate move, i.e. is that position open to play?
    is_legit_move = action in get_open_spaces(self.board_state)

    # Shouldn't be an illegal action -- is a noop if not strict.
    if self.strict:
      assert is_legit_move

    # If strict mode is off, then let this be a noop and env not play either.
    if not is_legit_move:
      return self.board_state.copy(), 0, False, {}

    # This is a legit move, perform the action and check if done, etc etc.
    i, j = decode_pos(action)
    self.board_state[i, j] = 1
    reward, done = get_reward_and_done(self.board_state)

    if done:
      self.done = True
      return self.board_state.copy(), reward, True, {}

    # If not done already, play our move.
    self.play_random_move()
    reward, done = get_reward_and_done(self.board_state)
    self.done = done

    return self.board_state.copy(), reward, self.done, {}

  def hparams(self, defaults, unused_model_hparams):
    """Sets identity-symbol modalities and vocab sizes on `defaults`."""
    p = defaults
    p.modality = {
        "inputs": modalities.ModalityType.IDENTITY_SYMBOL,
        "targets": modalities.ModalityType.IDENTITY_SYMBOL,
    }
    p.vocab_size = {
        "inputs": 3,  # since at each box, the input is either x, o or -.
        # nevermind that we have a 3x3 box.
        "targets": 3,  # -1, 0, 1
    }
    p.input_space_id = 0  # problem.SpaceID.GENERIC
    p.target_space_id = 0  # problem.SpaceID.GENERIC
# TODO(afrozm): Figure out how to get rid of this.
class DummyPolicyProblemTTT(problem.Problem):
  """Placeholder Problem that wraps a TicTacToeEnv for running the policy."""

  def __init__(self):
    super(DummyPolicyProblemTTT, self).__init__()
    self._ttt_env = TicTacToeEnv()

  @property
  def num_actions(self):
    """Number of discrete actions of the wrapped env."""
    return self._ttt_env.action_space.n

  def hparams(self, defaults, model_hparams):
    """Applies the env's hparams, then adds action/reward/policy entries."""
    self._ttt_env.hparams(defaults, model_hparams)

    # TODO: do these belong here?
    symbol = modalities.ModalityType.SYMBOL_WEIGHTS_ALL
    identity = modalities.ModalityType.IDENTITY
    defaults.modality.update({
        "input_action": symbol,
        "input_reward": symbol,
        "target_action": symbol,
        "target_reward": symbol,
        "target_policy": identity,
        "target_value": identity,
    })

    action_count = self.num_actions
    defaults.vocab_size.update({
        "input_action": action_count,
        "input_reward": 3,  # -1, 0, +1 ?
        "target_action": action_count,
        "target_reward": 3,  # -1, 0, +1 ?
        "target_policy": None,
        "target_value": None,
    })
def register():
  """Registers `TicTacToeEnv` with gym under version "v0"."""
  # The returned id and env object are not needed here.
  _, _ = gym_utils.register_gym_env(
      "tensor2tensor.envs.tic_tac_toe_env:TicTacToeEnv", version="v0")


# Registration happens as a side effect of importing this module.
# TODO(afrozm): Fix the registration and make it automatic.
register()
================================================
FILE: tensor2tensor/envs/tic_tac_toe_env_problem.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TicTacToeEnvProblem wraps the TicTacToeEnv in an EnvProblem."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.envs import gym_env_problem
from tensor2tensor.layers import modalities
from tensor2tensor.utils import registry
@registry.register_env_problem
class TicTacToeEnvProblem(gym_env_problem.GymEnvProblem):
  """Plays `batch_size` games of tic-tac-toe."""

  def __init__(self):
    super(TicTacToeEnvProblem, self).__init__(
        base_env_name="T2TEnv-TicTacToeEnv-v0",
        reward_range=(-1, 1))

  @property
  def input_modality(self):
    return modalities.ModalityType.IDENTITY_SYMBOL

  @property
  def target_modality(self):
    return modalities.ModalityType.IDENTITY_SYMBOL

  @property
  def action_modality(self):
    return modalities.ModalityType.SYMBOL_WEIGHTS_ALL

  @property
  def input_vocab_size(self):
    # Each cell is either x, o or empty.
    return 3

  @property
  def target_vocab_size(self):
    # The reward is one of -1, 0 or +1.
    return 3
================================================
FILE: tensor2tensor/envs/tic_tac_toe_env_problem_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.envs.tic_tac_toe_env_problem."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.envs import env_problem_utils
from tensor2tensor.envs import tic_tac_toe_env # pylint: disable=unused-import
from tensor2tensor.envs import tic_tac_toe_env_problem # pylint: disable=unused-import
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
class TicTacToeEnvProblemTest(tf.test.TestCase):

  def test_registration_and_interaction_with_env_problem(self):
    batch_size = 5
    # Looking the problem up by name checks that registration happened.
    env = registry.env_problem("tic_tac_toe_env_problem",
                               batch_size=batch_size)
    env.reset()
    total_done = lost = won = drawn = 0
    for _ in range(100):
      actions = np.stack(
          [env.action_space.sample() for _ in range(batch_size)])
      obs, rewards, dones, infos = env.step(actions)
      # Every returned quantity must be batched.
      for batched in (obs, rewards, dones, infos):
        self.assertEqual(batch_size, len(batched))
      env.reset(env_problem_utils.done_indices(dones))
      total_done += sum(dones)
      # Classify every finished game by its final reward.
      for reward, finished in zip(rewards, dones):
        if finished:
          if reward == -1:
            lost += 1
          elif reward == 0:
            drawn += 1
          elif reward == 1:
            won += 1
          else:
            raise ValueError(
                "reward should be -1, 0, 1 but is {}".format(reward))
    # Some game must have finished, otherwise the next check proves nothing.
    self.assertGreater(total_done, 0)
    # Every finished game is exactly one of won, lost or drawn.
    self.assertEqual(total_done, won + lost + drawn)


if __name__ == "__main__":
  tf.test.main()
================================================
FILE: tensor2tensor/envs/tic_tac_toe_env_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.envs.tic_tac_toe_env."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.envs import tic_tac_toe_env as ttt_env
import tensorflow.compat.v1 as tf
class TicTacToeEnvTest(tf.test.TestCase):

  def test_start(self):
    ttt = ttt_env.TicTacToeEnv(strict=True)
    self.assertFalse(ttt.done)
    # At max one move may have been played by the env.
    spaces = ttt_env.get_open_spaces(ttt.board_state)
    num_open_spaces = len(spaces)
    # i.e. either 8 or 9
    self.assertGreater(num_open_spaces, 7)
    # Play a move
    observation, reward, done, unused_info = ttt.step(spaces[0])
    # The environment should also have played a move.
    spaces = ttt_env.get_open_spaces(observation)
    self.assertEqual(num_open_spaces - 2, len(spaces))
    # Since at most 3 moves have been played, the game can't end.
    self.assertEqual(reward, 0)
    self.assertFalse(done)

  def test_env_actions(self):
    # The environment keeps moving while we never do, so it must eventually
    # win.
    ttt = ttt_env.TicTacToeEnv(strict=True)
    for _ in range(9):
      ttt.play_random_move()
      # `play_random_move` does not update `ttt.done`, so inspect the board
      # directly to stop as soon as the game is over.
      _, game_over = ttt_env.get_reward_and_done(ttt.board_state)
      if game_over:
        break
    reward, done = ttt_env.get_reward_and_done(ttt.board_state)
    self.assertEqual(-1, reward)
    self.assertTrue(done)

  def test_keep_playing(self):
    ttt = ttt_env.TicTacToeEnv(strict=False)
    done = False
    while not done:
      # sample an action from the action space.
      action = ttt.action_space.sample()
      # play it -- could be a no-op since we don't see if positions are empty.
      unused_observation, reward, done, unused_info = ttt.step(action)

    # done is True, so either:
    # we won
    # env won or
    # no space left
    we_won = reward == 1
    env_won = reward == -1
    space = bool(ttt_env.get_open_spaces(ttt.board_state))
    self.assertTrue(we_won or env_won or not space)


if __name__ == '__main__':
  tf.test.main()
================================================
FILE: tensor2tensor/envs/time_step.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimeStep is a simple class that holds the information seen at a time-step.
Let:
r_t = Reward(s_{t-1}, a_{t-1}, s_t) - reward for getting into a state.
d_t = Done(s_t) - is this state terminal.
a_t = Action performed at state s_t
i_t = (optional) Dictionary of key, value pairs of miscellaneous data.
Then the sequence of states, actions and rewards looks like the following:
s0, a0/i0 s1/r1/d1, a1/i1 s2/r2/d2, a2/i2 s3/r3/d3, ...
TimeStep holds (s_t, d_t, r_t, a_t, i_t), where r_t is stored both as a raw and
as a processed reward.
NOTE: When we call step on an environment at time-step t, we supply a_t and in
return the env gives us s_{t+1}, d_{t+1}, r_{t+1}.
So, we'd have to add the actions a_t/i_t to the current time-step, but add the
observations, rewards and dones to a new time-step.
NOTE: wrt `info` - A good solution could be to have two additional fields in
TimeStep - structured algo_info (a namedtuple, possibly different for every
algorithm, or None if we don't use any) and unstructured env_info (a dict).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
class TimeStep(collections.namedtuple("TimeStep", [
    "observation", "done", "raw_reward", "processed_reward", "action", "info"
])):
  """An immutable record of what was seen and done at one time-step."""

  def replace(self, **kwargs):
    """Returns a NEW TimeStep with the given fields replaced.

    The namedtuple is immutable, so `self` is left untouched; this permits
    calls like `ts.replace(action=a, raw_reward=r)`.
    """
    return self._replace(**kwargs)

  @classmethod
  def create_time_step(cls,
                       observation=None,
                       done=False,
                       raw_reward=None,
                       processed_reward=None,
                       action=None,
                       info=None):
    """Builds a TimeStep in which every field except `done` defaults to None."""
    return cls(
        observation=observation,
        done=done,
        raw_reward=raw_reward,
        processed_reward=processed_reward,
        action=action,
        info=info)
================================================
FILE: tensor2tensor/envs/time_step_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.envs.time_step."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.envs import time_step
import tensorflow.compat.v1 as tf
class TimeStepTest(tf.test.TestCase):

  def test_create_time_step(self):
    step = time_step.TimeStep.create_time_step(
        observation=1, done=True, raw_reward=1.0, processed_reward=1, action=1,
        info={1: 1, 2: 4})
    self.assertEqual(1, step.observation)
    self.assertTrue(step.done)
    self.assertNear(1.0, step.raw_reward, 1e-6)
    self.assertEqual(1, step.processed_reward)
    self.assertEqual(1, step.action)
    self.assertEqual({1: 1, 2: 4}, step.info)

  def test_replace(self):
    original = time_step.TimeStep.create_time_step(observation=1, action=1)
    self.assertFalse(original.done)
    replaced = original.replace(action=2, done=True, info={1: 1, 2: 4})

    # Assert that the original didn't change.
    self.assertFalse(original.done)
    self.assertEqual(1, original.observation)
    self.assertEqual(1, original.action)

    # But the replaced copy is as expected.
    self.assertTrue(replaced.done)
    self.assertEqual(1, replaced.observation)  # unchanged
    self.assertEqual(2, replaced.action)  # changed
    self.assertEqual({1: 1, 2: 4}, replaced.info)


if __name__ == '__main__':
  tf.test.main()
================================================
FILE: tensor2tensor/envs/trajectory.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trajectory manages a sequence of TimeSteps.
BatchTrajectory manages a batch of trajectories, also keeping account of
completed trajectories.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import re
import sys
import time
from absl import logging
import cloudpickle
import numpy as np
from tensor2tensor.envs import time_step
import tensorflow.compat.v1 as tf
# Filename template for pickled trajectories; fill it with `str.format`.
TRAJECTORY_FILE_FORMAT = r"trajectory_epoch_{epoch}_env_id_{env_id}_temperature_{temperature}_r_{r}.pkl"


def get_pickle_module():
  """Returns the stdlib `pickle` on Python 3, `cloudpickle` on Python 2."""
  return pickle if sys.version_info[0] >= 3 else cloudpickle
class Trajectory(object):
"""Basically a list of TimeSteps with convenience methods."""
def __init__(self, time_steps=None):
# Contains a list of time steps.
if time_steps is None:
self._time_steps = []
else:
self._time_steps = time_steps
def __str__(self):
if not self.time_steps:
return "Trajectory[]"
return "Trajectory[{}]".format(", ".join(str(ts) for ts in self.time_steps))
def add_time_step(self, **create_time_step_kwargs):
"""Creates a time-step and appends it to the list.
Args:
**create_time_step_kwargs: Forwarded to
time_step.TimeStep.create_time_step.
"""
ts = time_step.TimeStep.create_time_step(**create_time_step_kwargs)
assert isinstance(ts, time_step.TimeStep)
self._time_steps.append(ts)
def change_last_time_step(self, **replace_time_step_kwargs):
"""Replace the last time-steps with the given kwargs."""
# Pre-conditions: self._time_steps shouldn't be empty.
assert self._time_steps
self._time_steps[-1] = self._time_steps[-1].replace(
**replace_time_step_kwargs)
def truncate(self, num_to_keep=1):
"""Truncate trajectories, keeping the last `num_to_keep` time-steps."""
# We return `ts_copy` back to the truncator.
ts_copy = self._time_steps[:]
# We keep the last few observations.
self._time_steps = self._time_steps[-num_to_keep:]
# NOTE: We will need to set the rewards to 0, to eliminate double counting.
for i in range(self.num_time_steps):
self._time_steps[i] = self._time_steps[i].replace(
raw_reward=0, processed_reward=0)
return Trajectory(time_steps=ts_copy)
@property
def last_time_step(self):
# Pre-conditions: self._time_steps shouldn't be empty.
assert self._time_steps
return self._time_steps[-1]
@property
def num_time_steps(self):
return len(self._time_steps)
@property
def is_active(self):
return bool(self.num_time_steps)
@property
def time_steps(self):
return self._time_steps
@property
def done(self):
return self.is_active and self.last_time_step.done
# TODO(afrozm): Add discounting and rewards-to-go when it makes sense.
@property
def reward(self):
"""Returns a tuple of sum of raw and processed rewards."""
raw_rewards, processed_rewards = 0, 0
for ts in self.time_steps:
# NOTE: raw_reward and processed_reward are None for the first time-step.
if ts.raw_reward is not None:
raw_rewards += ts.raw_reward
if ts.processed_reward is not None:
processed_rewards += ts.processed_reward
return raw_rewards, processed_rewards
@property
def observations_np(self):
return np.stack([ts.observation for ts in self.time_steps])
def last_n_observations_np(self, n=None):
if n is not None:
n = -n # pylint: disable=invalid-unary-operand-type
return np.stack([ts.observation for ts in self.time_steps[n:]])
@property
def actions_np(self):
# The last action is None, so let's skip it.
return np.stack([ts.action for ts in self.time_steps[:-1]])
@property
def info_np(self):
  """Per-key stacked `info` values, or None when there is no info.

  Keys come from the first time-step's info dict. As with actions, the last
  time-step's info is missing, so it is skipped.
  """
  steps = self.time_steps
  if not steps:
    return None
  first_info = steps[0].info
  if not first_info:
    return None
  head = steps[:-1]
  return {
      key: np.stack([step.info[key] for step in head]) for key in first_info
  }
@property
def rewards_np(self):
  """Stacked processed rewards, skipping the first time-step (its reward is None)."""
  after_first = self.time_steps[1:]
  return np.stack([step.processed_reward for step in after_first])
@property
def raw_rewards_np(self):
  """Stacked raw rewards, skipping the first time-step (its reward is None)."""
  after_first = self.time_steps[1:]
  return np.stack([step.raw_reward for step in after_first])
@property
def as_numpy(self):
  """Returns (observations, actions, rewards, raw_rewards, info) as numpy.

  `info` is whatever `info_np` returns (a dict of arrays, or None).
  """
  # TODO(afrozm): Return a named tuple here, ex: TrajectoryArrays
  observations = self.observations_np
  actions = self.actions_np
  rewards = self.rewards_np
  raw_rewards = self.raw_rewards_np
  info = self.info_np
  return (observations, actions, rewards, raw_rewards, info)
class BatchTrajectory(object):
  """Basically a batch of active trajectories and a list of completed ones."""

  def __init__(self,
               batch_size=1,
               trajectories=None,
               completed_trajectories=None):
    """Creates the batch.

    Args:
      batch_size: int, number of active trajectory slots.
      trajectories: optional list of active `Trajectory` objects. If falsy,
        `batch_size` new empty trajectories are created.
      completed_trajectories: optional list of completed `Trajectory`
        objects. If falsy, the completed list starts out empty.
    """
    self.batch_size = batch_size

    # Stores trajectories that are currently active, i.e. aren't done or reset.
    self._trajectories = trajectories or [
        Trajectory() for _ in range(self.batch_size)
    ]

    # Stores trajectories that are completed.
    # NOTE: We don't track the index this came from, as it's not needed, right?
    self._completed_trajectories = completed_trajectories or []

  def reset_batch_trajectories(self):
    """Drops all active and completed trajectories, keeping `batch_size`."""
    self.__init__(batch_size=self.batch_size)

  def __str__(self):
    string = "BatchTrajectory["
    for i, t in enumerate(self.trajectories):
      string += "Trajectory {} = {}\n".format(i, str(t))
    for i, t in enumerate(self.completed_trajectories):
      string += "Completed Trajectory {} = {}\n".format(i, str(t))
    return string + "]"

  @property
  def trajectories(self):
    """List of active trajectories, one per batch index."""
    return self._trajectories

  @property
  def completed_trajectories(self):
    """List of completed trajectories (done, reset or truncated)."""
    return self._completed_trajectories

  def clear_completed_trajectories(self, num=None):
    """Clear the first `num` completed trajectories, or all if num is None."""
    if num is None:
      self._completed_trajectories = []
    else:
      self._completed_trajectories = self._completed_trajectories[num:]

  def _complete_trajectory(self, trajectory, index):
    """Moves `trajectory` to the completed list and empties slot `index`."""
    assert isinstance(trajectory, Trajectory)

    # This *should* be the case: no action was taken from the last time-step.
    assert trajectory.last_time_step.action is None

    # Add to completed trajectories.
    self._completed_trajectories.append(trajectory)

    # Make a new (inactive) one to replace it.
    self._trajectories[index] = Trajectory()

  def truncate_trajectories(self, indices, num_to_keep=1):
    """Truncate trajectories at specified indices.

    For each index, a trajectory holding the full pre-truncation history is
    appended to the completed list, while the active trajectory at that index
    is truncated in place to its last `num_to_keep` time-steps (whose rewards
    are zeroed by `Trajectory.truncate`).

    Args:
      indices: iterable with the indices to truncate.
      num_to_keep: int, number of last time-steps to keep while truncating.
    """
    for index in indices:
      trajectory = self._trajectories[index]
      assert trajectory.is_active, "Trajectory to truncate can't be inactive."

      # Now `trajectory` just consists of the last `num_to_keep` observations
      # and actions. Rewards are zeroed out.
      # The old data is placed in `old_trajectory`.
      old_trajectory = trajectory.truncate(num_to_keep=num_to_keep)

      # We put the old data in _completed_trajectories.
      self._completed_trajectories.append(old_trajectory)

  def reset(self, indices, observations):
    """Resets trajectories at given indices and populates observations.

    Reset can either be called right at the beginning, when there are no
    time-steps, or to reset a currently active trajectory.

    If resetting a currently active trajectory then we save it in
    self._completed_trajectories.

    Args:
      indices: 1-D np.ndarray stating the indices to reset.
      observations: np.ndarray of shape (indices len, obs.shape) of observations
    """

    # Pre-conditions: indices, observations are np arrays.
    #               : indices is one-dimensional.
    #               : their first dimension (batch) is the same.
    assert isinstance(indices, np.ndarray)
    assert len(indices.shape) == 1
    assert isinstance(observations, np.ndarray)
    assert indices.shape[0] == observations.shape[0]

    for index, observation in zip(indices, observations):
      trajectory = self._trajectories[index]

      # Are we starting a new trajectory at the given index?
      if not trajectory.is_active:
        # Then create a new time-step here with the given observation.
        trajectory.add_time_step(observation=observation)
        # That's all we need to do here.
        continue

      # If however we are resetting a currently active trajectory then we need
      # to put that in self._completed_trajectories and make a new trajectory
      # with the current observation.

      # TODO(afrozm): Should we mark these are done? Or is the done=False and
      # this being the last time-step in the trajectory good enough to recognize
      # that this was reset?

      # Mark trajectory as completed and move into completed_trajectories.
      self._complete_trajectory(trajectory, index)

      # Put the observation in the newly created trajectory.
      # TODO(afrozm): Add 0 reward.
      self._trajectories[index].add_time_step(observation=observation)

  def complete_all_trajectories(self):
    """Essentially same as reset, but we don't have observations."""
    for index in range(self.batch_size):
      trajectory = self._trajectories[index]
      # TODO(pkozakowski): This assertion breaks something in SimPLe trajectory
      # collection code - we're probably doing something wrong there. Commenting
      # out the assertion as a temporary measure.
      # assert trajectory.is_active
      if trajectory.is_active:
        self._complete_trajectory(trajectory, index)

  def step(self,
           observations,
           raw_rewards,
           processed_rewards,
           dones,
           actions,
           infos=None):
    """Record the information obtained from taking a step in all envs.

    Records (observation, rewards, done) in a new time-step and actions in the
    current time-step.

    If any trajectory gets done, we move that trajectory to
    completed_trajectories.

    Args:
      observations: ndarray of first dimension self.batch_size, which has the
        observations after we've stepped, i.e. s_{t+1} where t is the current
        state.
      raw_rewards: ndarray of first dimension self.batch_size containing raw
        rewards i.e. r_{t+1}.
      processed_rewards: ndarray of first dimension self.batch_size containing
        processed rewards. i.e. r_{t+1}
      dones: ndarray of first dimension self.batch_size, containing true at an
        index if that env is done, i.e. d_{t+1}
      actions: ndarray of first dimension self.batch_size, containing actions
        applied at the current time-step, which leads to the observations
        rewards and done at the next time-step, i.e. a_t
      infos: (optional) a dictionary of keys and values, where all the values
        have the first dimension as self.batch_size.
    """
    # Pre-conditions
    assert isinstance(observations, np.ndarray)
    assert isinstance(raw_rewards, np.ndarray)
    assert isinstance(processed_rewards, np.ndarray)
    assert isinstance(dones, np.ndarray)
    assert isinstance(actions, np.ndarray)
    if infos:
      assert isinstance(infos, dict)

    # We assume that we step in all envs, i.e. not like reset where we can reset
    # some envs and not others.
    assert self.batch_size == observations.shape[0]
    assert self.batch_size == raw_rewards.shape[0]
    assert self.batch_size == processed_rewards.shape[0]
    assert self.batch_size == dones.shape[0]
    assert self.batch_size == actions.shape[0]
    if infos:
      for _, v in infos.items():
        assert self.batch_size == len(v)

    def extract_info_at_index(infos, index):
      if not infos:
        return None
      return {k: v[index] for k, v in infos.items()}

    for index in range(self.batch_size):
      trajectory = self._trajectories[index]

      # NOTE: If the trajectory isn't active, that means it doesn't have any
      # time-steps in it, but we are in step, so the assumption is that it has
      # a prior observation from which we are stepping away from.

      # TODO(afrozm): Let's re-visit this if it becomes too restrictive.
      assert trajectory.is_active

      # To this trajectory's last time-step, set actions.
      trajectory.change_last_time_step(
          action=actions[index], info=extract_info_at_index(infos, index))

      # Create a new time-step to add observation, done & rewards (no actions).
      trajectory.add_time_step(
          observation=observations[index],
          done=dones[index],
          raw_reward=raw_rewards[index],
          processed_reward=processed_rewards[index])

      # If the trajectory is completed, i.e. dones[index] == True, then we
      # account for it right-away.
      if dones[index]:
        self._complete_trajectory(trajectory, index)

        # NOTE: The new trajectory at `index` is going to be in-active and
        # `reset` should be called on it.
        assert not self._trajectories[index].is_active

  @staticmethod
  def _trajectory_lengths(trajectories):
    """Array with the number of time-steps in each given trajectory."""
    return np.array([t.num_time_steps for t in trajectories])

  @property
  def num_completed_time_steps(self):
    """Returns the number of time-steps in completed trajectories."""

    return sum(BatchTrajectory._trajectory_lengths(self.completed_trajectories))

  @property
  def num_time_steps(self):
    """Returns the number of time-steps in completed and incomplete trajectories."""

    num_time_steps = sum(BatchTrajectory._trajectory_lengths(self.trajectories))
    return num_time_steps + self.num_completed_time_steps

  @property
  def trajectory_lengths(self):
    """Array with the number of time-steps in each active trajectory."""
    return BatchTrajectory._trajectory_lengths(self.trajectories)

  @property
  def num_completed_trajectories(self):
    """Returns the number of completed trajectories."""
    return len(self.completed_trajectories)

  # TODO(afrozm): Take in an already padded observation ndarray and just append
  # the last time-step and adding more padding if needed.
  def observations_np(self, boundary=20, len_history_for_policy=20):
    """Pads the observations in all the trajectories and returns them.

    Args:
      boundary: integer, Observations will be padded to (n * boundary) + 1 where
        n is an integer.
      len_history_for_policy: int, For each trajectory return only the last
        `len_history_for_policy` observations. Set to None for all the
        observations.

    Returns:
      A tuple (padded_observations, trajectory_lengths) where
        padded_observations: (self.batch_size, n * boundary + 1) + OBS, zero
          padded at the end of the time axis.
        trajectory_lengths: 1-D array with the number of (unpadded)
          observations taken from each trajectory.
    """
    list_observations_np_ts = [
        t.last_n_observations_np(n=len_history_for_policy)
        for t in self.trajectories
    ]
    # Every element in `list_observations_np_ts` is shaped (t,) + OBS
    OBS = list_observations_np_ts[0].shape[1:]  # pylint: disable=invalid-name

    trajectory_lengths = np.array(
        [obs.shape[0] for obs in list_observations_np_ts])
    t_max = max(trajectory_lengths)

    # t_max is rounded up to a multiple of `boundary`.
    boundary = int(boundary)
    bucket_length = boundary * int(np.ceil(float(t_max) / boundary))

    def padding_config(obs):
      # We're padding the first axis only, since that is the time-step.
      num_to_pad = bucket_length + 1 - obs.shape[0]
      return [(0, num_to_pad)] + [(0, 0)] * len(OBS)

    return np.stack([
        np.pad(obs, padding_config(obs), "constant")
        for obs in list_observations_np_ts
    ]), trajectory_lengths

  @staticmethod
  def parse_trajectory_file_name(trajectory_file_name):
    """Parse out the trajectory file's groups and return to caller.

    Returns:
      (epoch, env_id, temperature, random_string) or None if the name does not
      match the trajectory file format or its fields don't parse.
    """
    base_trajectory_file_name = os.path.basename(trajectory_file_name)
    trajectory_file_regexp = TRAJECTORY_FILE_FORMAT.format(
        epoch="(.*)",
        env_id="(.*)",
        temperature="(.*)",
        r="(.*)",
    )
    compiled_regexp = re.compile(trajectory_file_regexp)
    r = compiled_regexp.match(base_trajectory_file_name)
    if not r:
      return None
    g = r.groups()
    # Compare counts by value; `is not` on ints only works by interning luck.
    if len(g) != compiled_regexp.groups:
      return None
    # epoch, env_id, temp, random string
    try:
      epoch = int(g[0])
      env_id = int(g[1])
      temperature = float(g[2])
      random_string = g[3]
    except ValueError:
      logging.error("Trajectory file name isn't parseable: %s",
                    base_trajectory_file_name)
      return None
    return epoch, env_id, temperature, random_string

  @staticmethod
  def load_from_directory(trajectory_dir,
                          epoch=None,
                          temperature=None,
                          n_trajectories=None,
                          up_sample=False,
                          sleep_time_secs=0.1,
                          max_tries=100,
                          wait_forever=False):
    """Load trajectories from specified dir and epoch.

    Args:
      trajectory_dir: (string) directory to find trajectories.
      epoch: (int) epoch for which to load trajectories, if None we don't filter
        on an epoch.
      temperature: (float) this is used to filter the trajectory files, if None
        we don't filter on temperature.
      n_trajectories: (int) This is the batch size of the returned
        BatchTrajectory object if one is returned. If set to None, then the
        number of trajectories becomes the batch size. If set to some number,
        then we wait for those many trajectory files to be available.
      up_sample: (bool) If there are fewer than required (n_trajectories)
        trajectory files, then we sample files with replacement to make up the
        numbers. When there are enough files, `n_trajectories` distinct files
        are sampled without replacement.
      sleep_time_secs: (float) Initial sleep time while waiting for
        `n_trajectories` files. We exponentially back-off this up till a
        maximum of 10 seconds.
      max_tries: (int) The number of tries to get `n_trajectories` files.
      wait_forever: (bool) If true, overrides max_tries and waits forever.

    Returns:
      A BatchTrajectory object with all the constraints satisfied or None.
    """

    # Modify the format to get a glob with desired epoch and temperature.
    trajectory_file_glob = TRAJECTORY_FILE_FORMAT.format(
        epoch=epoch if epoch is not None else "*",
        env_id="*",
        temperature=temperature if temperature is not None else "*",
        r="*",
    )

    trajectory_files = tf.io.gfile.glob(
        os.path.join(trajectory_dir, trajectory_file_glob))

    if n_trajectories:
      # We need to get `n_trajectories` number of `trajectory_files`.
      # With the defaults (sleep_time_secs=0.1, max_tries=100) the back-off hits
      # the 10s cap after 7 tries, so this waits at most ~16 minutes.
      while ((max_tries > 0 or wait_forever) and
             len(trajectory_files) < n_trajectories):
        logging.info(
            "Sleeping for %s seconds while waiting for %s trajectories, found "
            "%s right now.", sleep_time_secs, n_trajectories,
            len(trajectory_files))
        time.sleep(sleep_time_secs)
        max_tries -= 1
        sleep_time_secs = min(10.0, sleep_time_secs * 2)
        trajectory_files = tf.io.gfile.glob(
            os.path.join(trajectory_dir, trajectory_file_glob))

      # We can't get the required number of files and we can't up-sample either.
      if (len(trajectory_files) < n_trajectories) and not up_sample:
        return None

      # Nothing to up-sample from; `np.random.choice` raises on an empty
      # population.
      if not trajectory_files:
        return None

      # Sample up or down as the case may be: without replacement when there
      # are enough files (so no file is picked twice), with replacement only
      # when up-sampling.
      trajectory_files = list(
          np.random.choice(
              trajectory_files,
              n_trajectories,
              replace=len(trajectory_files) < n_trajectories))

    # We read and load all the files, revisit if this becomes a problem.
    trajectories_buffer = []
    for trajectory_file in trajectory_files:
      # NOTE(review): unpickling can execute arbitrary code; only load from
      # directories whose contents are trusted.
      with tf.io.gfile.GFile(trajectory_file, "rb") as f:
        trajectory = get_pickle_module().load(f)
        assert isinstance(trajectory, Trajectory)
        trajectories_buffer.append(trajectory)

    if not trajectories_buffer:
      return None

    # If n_trajectories wasn't set, then set to the number of trajectories we're
    # returning.
    n_trajectories = n_trajectories or len(trajectories_buffer)

    # Construct and return a new BatchTrajectory object.
    return BatchTrajectory(
        batch_size=n_trajectories,
        trajectories=[Trajectory() for _ in range(n_trajectories)],
        completed_trajectories=trajectories_buffer)
================================================
FILE: tensor2tensor/envs/trajectory_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.envs.trajectory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensor2tensor.envs import time_step
from tensor2tensor.envs import trajectory
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.io import gfile
class TrajectoryTest(tf.test.TestCase):
  """Tests for the single-trajectory container `trajectory.Trajectory`."""

  def test_empty_trajectory(self):
    t = trajectory.Trajectory()
    self.assertFalse(t.is_active)
    self.assertEqual(0, t.num_time_steps)
    self.assertFalse(t.done)

  def test_add_time_step(self):
    t = trajectory.Trajectory()
    t.add_time_step(observation=1, done=True)

    # Test that the trajectory is now active.
    self.assertTrue(t.is_active)

    added_t = t.last_time_step
    self.assertEqual(1, added_t.observation)
    self.assertTrue(added_t.done)
    # Fields not passed to `add_time_step` are None.
    # NOTE: `assertIsNone(None, x)` treats `x` as the failure message and
    # always passes, so the value under test must be the first argument.
    self.assertIsNone(added_t.raw_reward)
    self.assertIsNone(added_t.processed_reward)
    self.assertIsNone(added_t.action)

    self.assertEqual(1, t.num_time_steps)

  def test_change_last_time_step(self):
    t = trajectory.Trajectory()
    t.add_time_step(observation=1, done=False)
    t.add_time_step(observation=1, done=True)
    self.assertTrue(t.is_active)

    num_ts_old = t.num_time_steps
    self.assertEqual(2, num_ts_old)

    # Assert on what the last time-step is currently.
    ts = t.last_time_step
    self.assertEqual(1, ts.observation)
    self.assertTrue(ts.done)
    self.assertIsNone(ts.action)

    # Change the last time-step.
    t.change_last_time_step(done=False, action=5)

    # Assert that it changed.
    ts = t.last_time_step
    self.assertEqual(1, ts.observation)  # unchanged, since we didn't change it.
    self.assertFalse(ts.done)  # was True earlier
    self.assertEqual(5, ts.action)  # was None earlier

    # Assert on the number of steps remaining the same as before.
    self.assertEqual(num_ts_old, t.num_time_steps)

  def test_reward(self):
    t = trajectory.Trajectory()
    # first time-step doesn't have rewards, since they are on entering a state.
    t.add_time_step(
        observation=1, raw_reward=None, processed_reward=None, done=False)
    t.add_time_step(
        observation=2, raw_reward=2, processed_reward=200, done=False)
    t.add_time_step(
        observation=3, raw_reward=3, processed_reward=300, done=True)

    raw_reward, processed_reward = t.reward

    self.assertEqual(5, raw_reward)
    self.assertEqual(500, processed_reward)

  def test_observation_np(self):
    t = trajectory.Trajectory()
    ts = 5
    shape = (3, 4)
    for _ in range(ts):
      t.add_time_step(observation=np.random.uniform(size=shape), done=False)

    self.assertEqual((ts,) + shape, t.observations_np.shape)

  def test_truncate_and_last_n_observations_np(self):
    t = trajectory.Trajectory()
    ts = 5
    shape = (3, 4)
    for _ in range(ts):
      t.add_time_step(observation=np.random.uniform(size=shape), done=False)

    original_obs = np.copy(t.observations_np)
    self.assertEqual((ts,) + shape, original_obs.shape)

    # Now let's just get the observations from the last 2 steps.
    num_to_keep = 2
    truncated_original_obs = original_obs[-num_to_keep:, ...]

    # Let's get the last `num_to_keep` observations
    last_n_observations_np = np.copy(t.last_n_observations_np(n=num_to_keep))

    # Now truncate the trajectory and get the same.
    _ = t.truncate(num_to_keep=num_to_keep)
    truncated_np = np.copy(t.observations_np)

    # These should be the expected length.
    self.assertEqual((2,) + shape, last_n_observations_np.shape)
    self.assertEqual((2,) + shape, truncated_np.shape)

    # Test the last `num_to_keep` are the same.
    self.assertAllEqual(truncated_np, truncated_original_obs)
    self.assertAllEqual(last_n_observations_np, truncated_original_obs)

  def test_as_numpy(self):
    t = trajectory.Trajectory()
    shape = (3, 4)

    # We'll have `ts` observations and `ts-1` actions and rewards.
    ts = 5
    num_actions = 6
    observations = np.random.uniform(size=(ts,) + shape)
    actions = np.random.choice(range(num_actions), size=(ts - 1,))
    rewards = np.random.choice([-1, 0, 1], size=(ts - 1,))
    squares = np.arange(ts - 1)**2
    cubes = np.arange(ts - 1)**3

    def get_info(i):
      return {"sq": squares[i], "cu": cubes[i]}

    # First time-step has no reward.
    t.add_time_step(
        observation=observations[0],
        done=False,
        action=actions[0],
        info=get_info(0))
    for i in range(1, ts - 1):
      t.add_time_step(
          observation=observations[i],
          done=False,
          raw_reward=rewards[i - 1],
          processed_reward=rewards[i - 1],
          action=actions[i],
          info=get_info(i))
    # Last time-step has no action.
    t.add_time_step(
        observation=observations[-1],
        done=False,
        raw_reward=rewards[-1],
        processed_reward=rewards[-1])

    traj_np = t.as_numpy

    self.assertAllEqual(observations, traj_np[0])
    self.assertAllEqual(actions, traj_np[1])
    self.assertAllEqual(rewards, traj_np[2])
    # Raw and processed rewards were given the same values above.
    self.assertAllEqual(rewards, traj_np[3])

    self.assertAllEqual(squares, traj_np[4]["sq"])
    self.assertAllEqual(cubes, traj_np[4]["cu"])
class BatchTrajectoryTest(tf.test.TestCase):
BATCH_SIZE = 10
OBSERVATION_SHAPE = (3, 4)
def get_random_observations_rewards_actions_dones(self, batch_size=None):
  """Returns random (observations, raw_rewards, actions, dones) for a batch.

  Args:
    batch_size: size of the leading dimension; falls back to
      `self.BATCH_SIZE` when falsy.

  Returns:
    observations shaped (batch_size,) + OBSERVATION_SHAPE, float raw rewards
    and actions shaped (batch_size,), and a boolean `dones` array.
  """
  n = batch_size if batch_size else self.BATCH_SIZE
  obs_shape = (n,) + self.OBSERVATION_SHAPE
  observations = np.random.rand(*obs_shape)
  raw_rewards = np.random.randn(n)
  actions = np.random.randn(n)
  # Each entry has a 40% chance of being done.
  dones = np.random.random((n,)) > 0.6
  return observations, raw_rewards, actions, dones
def test_creation(self):
  size = self.BATCH_SIZE
  batch = trajectory.BatchTrajectory(batch_size=size)
  # A fresh batch has one trajectory slot per element and nothing completed.
  self.assertEqual(size, len(batch.trajectories))
  self.assertEqual(0, batch.num_completed_trajectories)
def test_reset_all(self):
  bt = trajectory.BatchTrajectory(batch_size=self.BATCH_SIZE)
  indices = np.arange(self.BATCH_SIZE)
  observations, _, _, _ = self.get_random_observations_rewards_actions_dones()

  # Call reset.
  bt.reset(indices, observations)

  # Assert that all trajectories are active and not done (reset never marks
  # anything as done).
  self.assertTrue(all(t.is_active for t in bt.trajectories))
  self.assertFalse(any(t.done for t in bt.trajectories))
  self.assertEqual(0, bt.num_completed_trajectories)
def test_num_time_steps(self):
  # A freshly created batch holds no time-steps, active or completed.
  empty = trajectory.BatchTrajectory(batch_size=self.BATCH_SIZE)
  self.assertEqual(0, empty.num_completed_time_steps)
  self.assertEqual(0, empty.num_time_steps)
def test_reset_some(self):
  half = self.BATCH_SIZE // 2
  bt = trajectory.BatchTrajectory(batch_size=self.BATCH_SIZE)
  observations, _, _, _ = self.get_random_observations_rewards_actions_dones(
      batch_size=half)

  # Reset only the first half of the batch.
  bt.reset(np.arange(half), observations)

  first_half = bt.trajectories[:half]
  second_half = bt.trajectories[half:]
  # The reset half is active; the untouched half is not.
  self.assertTrue(all(t.is_active for t in first_half))
  self.assertTrue(all(not t.is_active for t in second_half))
  # Reset alone never completes anything.
  self.assertEqual(0, bt.num_completed_trajectories)
def test_truncate(self):
  batch_size = 1
  num_steps = 5
  num_to_keep = 2
  bt = trajectory.BatchTrajectory(batch_size=batch_size)
  indices = np.arange(batch_size)

  initial_obs = self.get_random_observations_rewards_actions_dones(
      batch_size=batch_size)[0]
  # reset() has to precede step().
  bt.reset(indices, initial_obs)

  for _ in range(num_steps):
    obs, rewards, actions, dones = (
        self.get_random_observations_rewards_actions_dones(
            batch_size=batch_size))
    dones[...] = False  # Keep the trajectory running.
    bt.step(obs, rewards, rewards, dones, actions)

  self.assertEqual(0, bt.num_completed_trajectories)

  bt.truncate_trajectories(indices, num_to_keep=num_to_keep)
  # The pre-truncation history is now a completed trajectory ...
  self.assertEqual(batch_size, bt.num_completed_trajectories)
  # ... and the active ones stay active, holding their last time-steps.
  self.assertTrue(all(t.is_active for t in bt.trajectories))

  orig_obs = bt.completed_trajectories[0].observations_np
  # One extra observation comes from the initial reset.
  self.assertEqual(num_steps + 1, orig_obs.shape[0])

  active = bt.trajectories[0]
  trunc_obs = active.observations_np
  self.assertEqual(num_to_keep, trunc_obs.shape[0])
  self.assertEqual(num_to_keep, active.num_time_steps)

  # The kept observations are exactly the tail of the original ones.
  self.assertAllEqual(orig_obs[-num_to_keep:, ...], trunc_obs)
def test_step(self):
  bt = trajectory.BatchTrajectory(batch_size=self.BATCH_SIZE)
  initial_obs = self.get_random_observations_rewards_actions_dones()[0]
  # reset() has to precede step().
  bt.reset(np.arange(self.BATCH_SIZE), initial_obs)

  next_obs, raw_rewards, actions, dones = (
      self.get_random_observations_rewards_actions_dones())
  processed_rewards = raw_rewards.astype(np.int64)

  # Guarantee at least one trajectory finishes, so there is something to test.
  dones[0] = True
  num_done = sum(dones)
  self.assertLessEqual(1, num_done)
  num_not_done = len(dones) - num_done

  bt.step(next_obs, raw_rewards, processed_rewards, dones, actions)

  # Every finished trajectory is moved to the completed list ...
  self.assertEqual(num_done, bt.num_completed_trajectories)
  # ... and only the unfinished ones remain active.
  self.assertEqual(num_not_done, sum(t.is_active for t in bt.trajectories))
def test_desired_placement_of_rewards_and_actions(self):
  batch_size = 1
  bt = trajectory.BatchTrajectory(batch_size=batch_size)
  observations = self.get_random_observations_rewards_actions_dones(
      batch_size=batch_size)[0]
  # reset() has to precede step().
  bt.reset(np.arange(batch_size), observations)

  new_observations, raw_rewards, actions, _ = (
      self.get_random_observations_rewards_actions_dones(
          batch_size=batch_size))
  processed_rewards = raw_rewards.astype(np.int64)
  dones = np.full(batch_size, False)

  bt.step(new_observations, raw_rewards, processed_rewards, dones, actions)

  # Nothing is done, so nothing is completed and the single slot stays active
  # with two time-steps: one from reset, one from step.
  self.assertEqual(0, bt.num_completed_trajectories)
  self.assertEqual(batch_size, len(bt.trajectories))
  traj = bt.trajectories[0]
  self.assertTrue(traj.is_active)
  self.assertEqual(2, traj.num_time_steps)
  first, second = traj.time_steps

  # The old observation comes first and the new one second; neither is done.
  self.assertAllEqual(observations[0], first.observation)
  self.assertAllEqual(new_observations[0], second.observation)
  self.assertEqual(False, first.done)
  self.assertEqual(False, second.done)

  # The action is stored on the time-step it was taken from.
  self.assertEqual(actions[0], first.action)
  self.assertIsNone(second.action)

  # Rewards are stored on the time-step they lead into, not the first one.
  self.assertNear(raw_rewards[0], second.raw_reward, 1e-6)
  self.assertIsNone(first.raw_reward)
  self.assertEqual(processed_rewards[0], second.processed_reward)
  self.assertIsNone(first.processed_reward)
def test_observations_np(self):
  batch = self.BATCH_SIZE
  boundary = 20
  bt = trajectory.BatchTrajectory(batch_size=batch)
  indices = np.arange(batch)
  first_obs = self.get_random_observations_rewards_actions_dones()[0]
  # reset() has to precede step(); every trajectory now has one time-step.
  bt.reset(indices, first_obs)
  lengths = np.full((batch,), 1)

  def take_steps(num_steps):
    """Steps every env `num_steps` times without finishing any trajectory."""
    for _ in range(num_steps):
      obs, rewards, actions, dones = (
          self.get_random_observations_rewards_actions_dones())
      dones[...] = False
      bt.step(obs, rewards, rewards, dones, actions)

  take_steps(5)
  # Lengths: (6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
  lengths = lengths + 5

  # Restart the first two trajectories.
  restart_obs = self.get_random_observations_rewards_actions_dones(
      batch_size=2)[0]
  bt.reset(np.array([0, 1]), restart_obs)
  # Lengths: (1, 1, 6, 6, 6, 6, 6, 6, 6, 6)
  lengths[0] = lengths[1] = 1

  take_steps(5)
  # Lengths: (6, 6, 11, 11, 11, 11, 11, 11, 11, 11)
  lengths = lengths + 5

  padded_obs_np, padded_lengths = bt.observations_np(
      boundary=boundary, len_history_for_policy=40)
  self.assertAllEqual(lengths, padded_lengths)
  # The longest trajectory fits into a single bucket of `boundary` + 1 slots.
  single_bucket_shape = (batch, boundary + 1) + self.OBSERVATION_SHAPE
  self.assertEqual(single_bucket_shape, padded_obs_np.shape)

  # Request histories of every length in [1, 2 * boundary).
  for history in range(1, 2 * boundary):
    expected_lengths = [min(length, history) for length in lengths]
    padded_obs_np, padded_lengths = bt.observations_np(
        boundary=boundary, len_history_for_policy=history)
    self.assertAllEqual(expected_lengths, padded_lengths)
    # No trajectory is longer than `boundary`, so the shape never grows.
    self.assertEqual(single_bucket_shape, padded_obs_np.shape)

  # Ten more steps push the longest trajectories past the boundary.
  take_steps(10)
  # Lengths: (16, 16, 21, 21, 21, 21, 21, 21, 21, 21)
  lengths = lengths + 10

  padded_obs_np, padded_lengths = bt.observations_np(
      boundary=boundary, len_history_for_policy=40)
  self.assertAllEqual(lengths, padded_lengths)
  self.assertEqual(
      (batch, (2 * boundary) + 1) + self.OBSERVATION_SHAPE,
      padded_obs_np.shape)

  # Only the padding is all zeros; a random observation being all zeros has
  # essentially zero probability.
  zero_obs = np.full(self.OBSERVATION_SHAPE, 0.)
  for b in range(batch):
    rows = padded_obs_np[b]
    for step_idx in range(lengths[b]):
      self.assertFalse(np.all(zero_obs == rows[step_idx]))
    for step_idx in range(lengths[b], len(rows)):
      self.assertAllEqual(zero_obs, rows[step_idx])
def test_parse_trajectory_file_name(self):
  parse = trajectory.BatchTrajectory.parse_trajectory_file_name
  # A well-formed name yields (epoch, env_id, temperature, random string).
  well_formed = (
      "/tmp/trajectory_epoch_000012_env_id_000013_temperature_1.0_r_abc.pkl")
  self.assertEqual((12, 13, 1.0, "abc"), parse(well_formed))
  # A name missing fields can't be parsed.
  self.assertIsNone(parse("/tmp/trajectory_epoch_000012_env_id_000013.pkl"))
def test_load_from_directory(self):
  """Writes a grid of one-step trajectories, then loads them with filters."""
  output_dir = self.get_temp_dir()
  epochs = [0, 1, 2]
  env_ids = [0, 1, 2]
  temperatures = [0.5, 1.0]
  random_strings = ["a", "b"]

  # One file per (epoch, env_id, temperature, random_string) combination:
  # 3 * 3 * 2 * 2 = 36 files. Only the epoch-0 trajectories are marked done,
  # i.e. 3 * 2 * 2 = 12 of them.
  combinations = [(epoch, env_id, temperature, random_string)
                  for epoch in epochs
                  for env_id in env_ids
                  for temperature in temperatures
                  for random_string in random_strings]
  for epoch, env_id, temperature, random_string in combinations:
    single_step = time_step.TimeStep(
        observation=epoch,
        done=(epoch == 0),
        raw_reward=1.0,
        processed_reward=1.0,
        action=env_id,
        info={})
    traj = trajectory.Trajectory(time_steps=[single_step])
    file_name = trajectory.TRAJECTORY_FILE_FORMAT.format(
        epoch=epoch, env_id=env_id, temperature=temperature, r=random_string)
    with gfile.GFile(os.path.join(output_dir, file_name), "w") as f:
      trajectory.get_pickle_module().dump(traj, f)

  load = trajectory.BatchTrajectory.load_from_directory

  def assert_loaded(expected_count, **kwargs):
    """Loads with `kwargs` and checks both the completed and batch counts."""
    loaded = load(output_dir, **kwargs)
    self.assertEqual(expected_count, loaded.num_completed_trajectories)
    self.assertEqual(expected_count, loaded.batch_size)

  # Without filters every written file is loaded.
  everything = load(output_dir)
  self.assertIsInstance(everything, trajectory.BatchTrajectory)
  self.assertEqual(36, everything.num_completed_trajectories)
  self.assertEqual(36, everything.batch_size)

  assert_loaded(12, epoch=0)
  # Only 12 trajectories match epoch 0, so a request for 100 of them (with no
  # retries) cannot be satisfied and nothing is returned.
  self.assertIsNone(load(output_dir, epoch=0, n_trajectories=100, max_tries=0))
  assert_loaded(6, epoch=0, temperature=0.5)
  assert_loaded(12, epoch=1)
  # 100 of the 12 epoch-1 trajectories is impossible without up-sampling...
  self.assertIsNone(load(
      output_dir, epoch=1, n_trajectories=100, up_sample=False, max_tries=0))
  # ...but becomes possible once up-sampling is allowed.
  assert_loaded(100, epoch=1, n_trajectories=100, up_sample=True, max_tries=0)
  assert_loaded(10, epoch=1, n_trajectories=10)

  gfile.rmtree(output_dir)
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/insights/README.md
================================================
# Tensor2Tensor Insights
The Insights package provides an interactive web service for understanding the
inner workings of a Tensor2Tensor model. It provides a series of
visualizations extracted from a requested T2T model that inform model developers
and model users about how to improve or best utilize a model.
## Dependencies
Before using the Insights server, you must install [Bower](https://bower.io/)
which we use to manage our web component dependencies. You can easily install
this with the [Node Package Manager](https://www.npmjs.com/).
## Setup Instructions
After training a model, such as according to the Quick Start guide, you can run
the `t2t-insights-server` binary and begin querying it.
First, prepare the bower dependencies by navigating into the
`tensor2tensor/insights/polymer` directory and running `bower install`:
```
pushd tensor2tensor/insights/polymer
bower install
popd
```
The models run by the server are then configured by a JSON version of the
`InsightConfiguration` protocol buffer (see `insight_configuration.proto`).
Using the model trained in the Quick Start guide, a sample configuration would be:
```
{
"configuration": [{
"source_language": "en",
"target_language": "de",
"label": "transformers_wmt32k",
"transformer": {
"model": "transformer",
"model_dir": "/tmp/t2t/train",
"data_dir": "/tmp/t2t/data",
"hparams": "",
"hparams_set": "transformer_base_single_gpu",
"problem": "translate_ende_wmt32k"
}
}],
"language": [{
"code": "en",
"name": "English"
},{
"code": "de",
"name": "German"
}]
}
```
With that saved to `configuration.json`, run the following:
```
t2t-insights-server \
--configuration=configuration.json \
--static_path=`pwd`/tensor2tensor/insights/polymer
```
This will bring up a minimal [Flask](http://flask.pocoo.org/) REST service
served by a [Gunicorn](http://gunicorn.org/) HTTP server.
## Features to be developed
This is a minimal web server. We are in the process of adding additional
exciting features that give insight into a model's behavior:
* Integrating a multi-head attention visualization.
* Registering multiple models to compare their behavior.
* Indexing training data to find examples related to a current query.
* Tracking interesting query + translation pairs for deeper analysis.
================================================
FILE: tensor2tensor/insights/__init__.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: tensor2tensor/insights/graph.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Graph representation for building decoding graph visualizations."""
class Vertex(object):
  """A graph node that records the indices of its incoming and outgoing edges.

  Edges are referenced by integer index rather than by object, so a vertex
  serializes directly to the JSON format expected by the client. `to_dict`
  produces:
    in_edge_index: indices of the directed edges that end at this vertex.
    out_edge_index: indices of the directed edges that start at this vertex.
  """

  def __init__(self, idx):
    """Creates a vertex with no edges.

    Args:
      idx: The integer index of this vertex.
    """
    self.idx = idx
    # Two distinct lists; each holds edge indices.
    self.in_edges, self.out_edges = [], []

  def to_dict(self):
    """Serializes this vertex's connectivity.

    Returns:
      A JSON-serializable dict with `in_edge_index` and `out_edge_index` keys.
      The edge lists are returned by reference, not copied.
    """
    return dict(in_edge_index=self.in_edges, out_edge_index=self.out_edges)
class Edge(object):
  """A directed connection between two vertices, referenced by vertex index.

  The edge serializes to the JSON format expected by the client. `to_dict`
  produces:
    source_index: index of the vertex the edge starts at (-1 until set).
    target_index: index of the vertex the edge ends at (-1 until set).
    data: arbitrary payload attached to the edge.
  """

  def __init__(self, idx):
    """Creates an edge with unset endpoints and an empty payload.

    Args:
      idx: The integer index of this edge.
    """
    self.idx = idx
    # -1 marks an endpoint that has not been assigned yet.
    self.source = self.target = -1
    self.data = {}

  def to_dict(self):
    """Serializes this edge.

    Returns:
      A JSON-serializable dict with `source_index`, `target_index` and `data`
      keys, in that order. `data` is returned by reference, not copied.
    """
    return dict(
        source_index=self.source,
        target_index=self.target,
        data=self.data,
    )

  def __str__(self):
    """Returns the string form of the serialized edge."""
    return "{}".format(self.to_dict())
class Graph(object):
  """A directed graph that serializes to JSON for visualization.

  Vertices and edges are kept in creation order, and every element's `idx`
  equals its position in `vertices` or `edges`. `to_dict` produces:
    node: the serialized Vertex instances.
    edge: the serialized Edge instances.
  """

  def __init__(self):
    self.vertices = []
    self.edges = []
    # Caller-supplied keys -> vertices created through `get_vertex`.
    self.vertex_map = {}

  def new_vertex(self):
    """Appends a new, unconnected vertex to the graph.

    Returns:
      The new Vertex; its index is its position in `self.vertices`.
    """
    vertex = Vertex(idx=len(self.vertices))
    self.vertices.append(vertex)
    return vertex

  def get_vertex(self, key):
    """Returns the vertex associated with `key`, creating it on first use.

    Args:
      key: A string reference for a vertex. An unseen key gets a new vertex.

    Returns:
      The Vertex mapped to by `key`.
    """
    vertex = self.vertex_map.get(key)
    if vertex is None:
      vertex = self.new_vertex()
      self.vertex_map[key] = vertex
    return vertex

  def add_edge(self, source, target):
    """Creates a directed edge from `source` to `target`.

    Both vertices record the new edge's index in their edge lists.

    Args:
      source: The Vertex the edge starts at.
      target: The Vertex the edge ends at.

    Returns:
      The new Edge.
    """
    edge = Edge(len(self.edges))
    self.edges.append(edge)
    source.out_edges.append(edge.idx)
    target.in_edges.append(edge.idx)
    edge.source, edge.target = source.idx, target.idx
    return edge

  def to_dict(self):
    """Serializes the whole graph.

    Returns:
      A JSON-serializable dict with `node` and `edge` lists.
    """
    serialized_nodes = [vertex.to_dict() for vertex in self.vertices]
    serialized_edges = [edge.to_dict() for edge in self.edges]
    return {"node": serialized_nodes, "edge": serialized_edges}
================================================
FILE: tensor2tensor/insights/insight_configuration.proto
================================================
syntax = "proto3";
package tensor2tensor;
// Configures the Neural Machine Translation Insight Frontend with a set of
// supported query processors and languages. The Insights server is
// configured with a JSON version of this message.
message InsightConfiguration {
  // Specifies zero or more models to inspect.
  repeated QueryProcessorConfiguration configuration = 1;
  // Specifies language codes and display names. Presumably each
  // source_language/target_language code used in `configuration` should have
  // an entry here -- TODO confirm against the server.
  repeated Language language = 2;
}
// A displayable language name, pairing a language code with the label shown
// in the UI.
message Language {
  // The BCP-47 Language code, e.g. "en".
  string code = 1;
  // The language's display name, e.g. "English".
  string name = 2;
}
// Configures a QueryProcessor and registers it with the Insight Frontend when
// responding to analysis queries.
message QueryProcessorConfiguration {
  // The model's BCP-47 source language code.
  string source_language = 1;
  // The model's BCP-47 target language code.
  string target_language = 2;
  // A short label for the model, e.g. "transformers_wmt32k".
  string label = 3;
  // The QueryProcessor to use. By default we just use the TransformerModel.
  string query_processor = 4;
  // Configuration for the TransformerModel.
  TransformerConfiguration transformer = 5;
}
// Specifies the parameters for a trained Transformer model to inspect. These
// parameters match those in t2t-trainer and t2t-decoder.
message TransformerConfiguration {
  // The model type, e.g. "transformer".
  string model = 1;
  // The trained model directory.
  string model_dir = 2;
  // The data directory for the model.
  string data_dir = 3;
  // The hyperparameter set for running the model.
  string hparams_set = 4;
  // Overriding hyperparameters.
  string hparams = 5;
  // The problem sets over which this model was trained and configured.
  // NOTE(review): the README's sample configuration uses the key "problem",
  // not "problems"; a strict proto3 JSON parser rejects unknown fields by
  // default. Confirm which name the server actually reads before changing
  // either side.
  string problems = 6;
}
================================================
FILE: tensor2tensor/insights/polymer/.bowerrc
================================================
{
"directory": "."
}
================================================
FILE: tensor2tensor/insights/polymer/attention_visualization/attention-visualization.html
================================================
{{selectedProbability}}
================================================
FILE: tensor2tensor/insights/polymer/attention_visualization/attention-visualization.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * `<attention-visualization>` presents a heatmap of input-output associations.
 *
 * The heat map shows source-to-target token association strengths according
 * to some method. Hovering over a cell highlights it together with its source
 * (row) and target (column) tokens.
 *
 * ### Usage
 *
 *     <attention-visualization data="[[attentionData]]">
 *     </attention-visualization>
 *
 * `data` is an `AttentionData` object. Its `weights` array is read row by row:
 * entry `i` belongs to source token `floor(i / target_tokens.length)` and
 * target token `i % target_tokens.length`.
 */
class AttentionVisualization extends Polymer.Element {
  constructor() {
    super();
    /**
     * D3.js group holding the labels and heatmap cells; it receives the
     * pan/zoom transform.
     * @private
     */
    this.container_ = undefined;
    /**
     * Pixel margins around the heatmap. Only `left` and `top` are currently
     * used, to offset the heatmap so the labels drawn to its left and above it
     * stay visible.
     * @private
     */
    this.margin_ = {
      top: 150,
      bottom: 50,
      right: 10,
      left: 100
    };
    /**
     * D3.js selection of the root SVG element.
     * @private
     */
    this.svg_ = undefined;
    /**
     * D3.js group offset by the margins; it captures zoom/pan gestures.
     * @private
     */
    this.vis_ = undefined;
    /**
     * D3.js zoom behavior attached to `vis_`.
     * @private
     */
    this.zoom_ = undefined;
  }

  /**
   * @return {string} The component name.
   */
  static get is() {
    return 'attention-visualization';
  }

  /**
   * @return {!Object} The component properties.
   */
  static get properties() {
    return {
      /**
       * @type {AttentionData}
       */
      data: {
        type: Object,
        observer: 'dataUpdated_',
      },
      /**
       * @type {number}
       */
      zoomDepth_: {
        type: Number,
      },
    };
  }

  /**
   * @return {!Array} The component observers.
   */
  static get observers() {
    return [
      'zoomDepthChanged_(zoomDepth_)',
    ];
  }

  /**
   * Sets the default zoom depth.
   * @override
   */
  ready() {
    super.ready();
    this.set('zoomDepth_', 20);
  }

  /**
   * Converts a zoom depth into a d3 scale factor (a depth of 20 is 1x). A
   * depth of 0 is clamped to a tiny positive value so the transform never
   * collapses to scale(0).
   * @param {number} zoomDepth The zoom depth.
   * @return {number} The scale factor.
   * @private
   */
  zoomScale_(zoomDepth) {
    if (zoomDepth == 0) {
      zoomDepth = 0.000001;
    }
    return zoomDepth / 20.0;
  }

  /**
   * Sets the zoom state based on the updated depth.
   * @param {number} zoomDepth the zoom depth.
   * @private
   */
  zoomDepthChanged_(zoomDepth) {
    if (!this.container_) {
      return;
    }
    const transform =
        d3.zoomTransform(this.vis_.node()).scale(this.zoomScale_(zoomDepth));
    this.container_.attr('transform', transform);
  }

  /**
   * Redraws the heatmap from scratch.
   * @param {?AttentionData} newData the new alignment data.
   * @private
   */
  dataUpdated_(newData) {
    // Nothing to draw without data; reading fields of null/undefined would
    // otherwise throw inside the property observer.
    if (!newData) {
      return;
    }
    // Side length, in pixels, of each square heatmap cell.
    const cellDimension = 40;
    const sourceTokens = newData.source_tokens;
    const targetTokens = newData.target_tokens;

    // Convert the flat attention weights to cell objects that also carry
    // their row (source token) and column (target token) indices.
    const mapCells = newData.weights.map((value, i) => ({
      value: value,
      row: Math.floor(i / targetTokens.length),
      col: i % targetTokens.length,
    }));

    // Quantize weights in [0, 1] onto eight grey levels; darker means
    // stronger.
    const colorScale = d3.scaleQuantile().domain([0.0, 1.0]).range([
      '#cccccc', '#b2b2b2', '#999999', '#7f7f7f',
      '#666666', '#4c4c4c', '#333333', '#191919'
    ]);

    // Apply pan/zoom gestures to the container, combined with the current
    // zoom depth.
    const zoomed = () => {
      this.container_.attr(
          'transform',
          d3.event.transform.scale(this.zoomScale_(this.zoomDepth_)));
    };
    this.zoom_ = d3.zoom().scaleExtent([1, 10]).on('zoom', zoomed);

    // Drop any previously rendered heatmap.
    d3.select(this.$.chart).selectAll('*').remove();

    // Create the bounding div and svgs which will contain all details.
    this.svg_ = d3.select(this.$.chart)
        .append('div')
        .classed('svg-container', true)
        .append('svg')
        .attr('width', '100%')
        .attr('height', '100%')
        .classed('svg-content-responsive', true);
    // Double-click and mouse-wheel zooming are disabled; zoom depth is driven
    // through `zoomDepth_` instead.
    this.vis_ = this.svg_.append('g')
        .attr('transform',
              'translate(' + this.margin_.left + ',' + this.margin_.top + ')')
        .call(this.zoom_)
        .on('dblclick.zoom', null)
        .on('wheel.zoom', null);

    // Create a transparent bounding rectangle that receives the pan/zoom
    // pointer events.
    this.vis_.append('rect')
        .attr('width', '100%')
        .attr('height', '100%')
        .style('fill', 'none')
        .style('pointer-events', 'all');
    this.container_ = this.vis_.append('g');

    // Place the source tokens along the vertical axis. Each token has an id
    // based on its index.
    const sourceLabels = this.container_.append('g');
    sourceLabels.selectAll('.source-label')
        .data(sourceTokens)
        .enter()
        .append('text')
        .text((d) => d)
        .style('text-anchor', 'end')
        .attr('id', (d, i) => 'row-' + i)
        .attr('class', 'source-label mono')
        .attr('transform', 'translate(-6,' + cellDimension / 1.5 + ')')
        .attr('x', 0)
        .attr('y', (d, i) => i * cellDimension);

    // Place the target tokens along the horizontal axis, rotated to read
    // upwards. Each token has an id based on its index. ('start' is the valid
    // SVG spelling of the left alignment the original 'left' value intended.)
    const targetLabels = this.container_.append('g');
    targetLabels.selectAll('.target-label')
        .data(targetTokens)
        .enter()
        .append('text')
        .text((d) => d)
        .style('text-anchor', 'start')
        .attr('id', (d, i) => 'col-' + i)
        .attr('class', 'target-label mono')
        .attr(
            'transform',
            'translate(' + cellDimension / 2 + ',-6) rotate(-90)')
        .attr('y', (d, i) => i * cellDimension)
        .attr('x', 0);

    // Create the heat map and populate it with cells. Each cell highlights
    // when hovered over, and so do its row and column tokens, to make clear
    // which tokens are being observed.
    const heatMap = this.container_.append('g');

    // Group the rectangle and text elements and capture the mouse events from
    // both so that the rectangle can be highlighted when it's in focus. The
    // handlers use `function` so that `this` is the hovered group element.
    const cellGroup = heatMap.selectAll('.cell')
        .data(mapCells)
        .enter()
        .append('g')
        .attr('class', 'cell-group')
        .on('mouseover', function(d) {
          // Highlight the hovered cell and its row/column tokens.
          d3.select(this).classed('cell-hover', true);
          sourceLabels.select('#row-' + d.row)
              .classed('text-highlight', true);
          targetLabels.select('#col-' + d.col)
              .classed('text-highlight', true);
        })
        .on('mouseout', function(d) {
          // Clear all highlighting.
          d3.select(this).classed('cell-hover', false);
          sourceLabels.select('#row-' + d.row)
              .classed('text-highlight', false);
          targetLabels.select('#col-' + d.col)
              .classed('text-highlight', false);
        });

    // Add the rectangles for each cell.
    cellGroup
        .append('rect')
        .attr('id', (d, i) => 'cell-' + i)
        .attr('class', 'cell cell-border')
        .attr('x', (d) => d.col * cellDimension)
        .attr('y', (d) => d.row * cellDimension)
        .attr('width', cellDimension)
        .attr('height', cellDimension)
        .style('fill', (d) => colorScale(d.value));

    // Add the weight text for each cell.
    cellGroup
        .append('text')
        .text((d) => d.value.toFixed(2))
        .attr('class', 'weight weight-label')
        .attr('x', (d) => 5 + (d.col * cellDimension))
        .attr('y', (d) => 25 + (d.row * cellDimension));
  }

  /**
   * Resets the pan and zoom state.
   * @private
   */
  reset_() {
    if (!this.svg_) {
      return;
    }
    this.vis_.call(this.zoom_.transform, d3.zoomIdentity);
    this.set('zoomDepth_', 20);
  }
}

customElements.define(AttentionVisualization.is, AttentionVisualization);
================================================
FILE: tensor2tensor/insights/polymer/bower.json
================================================
{
"name": "tensor2tensor-insights",
"homepage": "https://github.com/tensorflow/tensor2tensor",
"description": "Components for analyzing tensor2tensor neural machine translation models.",
"main": "index.html",
"keywords": [
"neural",
"machine",
"translation"
],
"authors": [
"kstevens@google.com"
],
"license": "Apache 2.0",
"private": true,
"ignore": [
"**/.*",
"node_modules",
"bower_components",
"test",
"tests"
],
"dependencies": {
"app-layout": "PolymerElements/app-layout#2.0.4",
"app-route": "PolymerElements/app-route#2.0.3",
"d3": "d3#4.12.2",
"iron-a11y-keys": "PolymerElements/iron-a11y-keys#2.0.0",
"iron-ajax": "PolymerElements/iron-ajax#2.0.0",
"iron-flex-layout": "PolymerElements/iron-flex-layout#2.0.0",
"iron-icon": "PolymerElements/iron-icon#2.0.0",
"iron-icons": "PolymerElements/iron-icons#2.0.0",
"iron-list": "PolymerElements/iron-list#2.0.0",
"iron-pages": "PolymerElements/iron-pages#2.0.0",
"iron-selector": "PolymerElements/iron-selector#2.0.0",
"neon-animation": "PolymerElements/neon-animation#2.0.0",
"paper-button": "PolymerElements/paper-button#2.0.0",
"paper-card": "PolymerElements/paper-card#2.0.0",
"paper-dialog": "PolymerElements/paper-dialog#2.0.0",
"paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#2.0.0",
"paper-icon-button": "PolymerElements/paper-icon-button#2.0.0",
"paper-input": "PolymerElements/paper-input#2.0.0",
"paper-item": "PolymerElements/paper-item#2.0.0",
"paper-listbox": "PolymerElements/paper-listbox#2.0.0",
"paper-slider": "PolymerElements/paper-slider#2.0.0",
"paper-tabs": "PolymerElements/paper-tabs#2.0.0",
"paper-toggle-button": "PolymerElements/paper-toggle-button#2.0.0",
"paper-tooltip": "PolymerElements/paper-tooltip#2.0.0",
"paper-progress": "PolymerElements/paper-progress#2.0.0",
"polymer": "polymer/polymer#v2.3.1"
},
"resolutions": {
"webcomponentsjs": "^v1.0.19",
"polymer": "^v2.3.1",
"app-route": "^2.0.3",
"app-layout": "^2.0.4",
"iron-location": "1 - 2",
"iron-selector": "^2.0.0",
"neon-animation": "^2.0.0",
"iron-icon": "^2.0.0",
"iron-pages": "^2.0.0",
"iron-icons": "^2.0.0",
"paper-icon-button": "^2.0.0",
"paper-item": "^2.0.0",
"iron-flex-layout": "^2.0.0",
"paper-listbox": "^2.0.0",
"iron-a11y-keys": "^2.0.0",
"paper-dialog": "^2.0.0",
"iron-ajax": "^2.0.0",
"paper-progress": "^2.0.0",
"paper-dropdown-menu": "^2.0.0",
"paper-tabs": "^2.0.0",
"paper-input": "^2.0.0",
"paper-toggle-button": "^2.0.0",
"paper-slider": "^2.0.0",
"iron-list": "^2.0.0",
"paper-card": "^2.0.0",
"paper-tooltip": "^2.0.0",
"iron-overlay-behavior": "^2.2.0"
}
}
================================================
FILE: tensor2tensor/insights/polymer/common-types.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * @fileoverview A set of shared types that will be replaced by js proto types.
 * They describe the plain-object shapes passed between the Insights Polymer
 * components; several mirror the JSON produced by the Python server (e.g. the
 * search-graph types below mirror insights/graph.py).
 * @externs
 */

/**
 * A typedef for a nlp.nmt.mt_debug_fe.LanguageConfiguration message.
 * This can't be converted to javascript yet because it transitively depends on
 * tensorflow protos that can't be converted to javascript.
 * `code` is the language code and `name` its display name; compare the
 * `Language` message in insight_configuration.proto, which has no `hidden`
 * field.
 * TODO(kstevens): Remove this typedef when we remove the dependency on
 * non-convertible tensorflow protos.
 * @typedef {{
 *   code: string,
 *   name: string,
 *   hidden: ?boolean,
 * }}
 */
let Language;

/**
 * A typedef for a nlp.nmt.mt_debug_fe.SerializedConfiguration message.
 * This can't be converted to javascript yet because it transitively depends on
 * tensorflow protos that can't be converted to javascript.
 * explore-view reads `id` and the `code` of both languages when building a
 * translation request.
 * TODO(kstevens): Remove this typedef when we remove the dependency on
 * non-convertible tensorflow protos.
 * @typedef {{
 *   id: string,
 *   target: string,
 *   source_language: Language,
 *   target_language: Language,
 * }}
 */
let Model;

/**
 * A node of the collapsible decoding tree drawn by graph-visualization.
 * NOTE(review): "cumalitiveProbability" is misspelled, but graph-visualization
 * uses the same spelling, so rename both together or not at all. That
 * component's root node also sets a `score` field that is not listed here.
 * @typedef {{
 *   name: string,
 *   localProbability: number,
 *   cumalitiveProbability: number,
 *   attention: Array,
 *   children: Array,
 * }}
 */
let TreeNode;

/**
 * Input to attention-visualization. `weights` is flattened row by row:
 * entry `i` is the weight between source token
 * `floor(i / target_tokens.length)` and target token
 * `i % target_tokens.length`.
 * @typedef {{
 *   source_tokens: Array,
 *   target_tokens: Array,
 *   weights: !Array
 * }}
 */
let AttentionData;

/**
 * One candidate token produced at a decoding step; `parent_id` presumably
 * refers to the id of the node it extends (TODO confirm).
 * @typedef {{
 *   label: string,
 *   label_id: number,
 *   log_probability: number,
 *   total_log_probability: number,
 *   score: number,
 *   parent_id: number,
 * }}
 */
let Candidate;

/**
 * A node in the interactive decoding tree: a candidate plus its children.
 * @typedef {{
 *   id: number,
 *   stepIndex: number,
 *   candidate: !Candidate,
 *   children: !Array,
 * }}
 */
let InteractiveNode;

/**
 * A single named query-processing step and the segments it produced.
 * @typedef {{
 *   step_name: string,
 *   segment: !Array
 * }}
 */
let QueryProcessingRewriteStep;

/**
 * The processing steps applied to the source and to the target text.
 * @typedef {{
 *   source_processing: !Array,
 *   target_processing: !Array,
 * }}
 */
let QueryProcessingVisualization;

/**
 * Matches the output of `Vertex.to_dict()` in insights/graph.py: lists of
 * indices into the enclosing graph's `edge` array.
 * @typedef {{
 *   in_edge_index: !Array,
 *   out_edge_index: !Array,
 * }}
 */
let BeamSearchNode;

/**
 * The per-edge payload of a search-graph edge.
 * @typedef {{
 *   label_id: number,
 *   label: string,
 *   log_probability: number,
 *   total_log_probability: number,
 *   score: number,
 *   completed: boolean,
 * }}
 */
let BeamSearchCandidate;

/**
 * Matches the output of `Edge.to_dict()` in insights/graph.py: the indices of
 * the source and target nodes plus the edge's data.
 * @typedef {{
 *   source_index: number,
 *   target_index: number,
 *   data: !BeamSearchCandidate,
 * }}
 */
let BeamSearchEdge;

/**
 * Matches the output of `Graph.to_dict()` in insights/graph.py.
 * @typedef {{
 *   node: !Array,
 *   edge: !Array,
 * }}
 */
let SearchGraphVisualization;

/**
 * The candidates offered for the next interactive decoding step.
 * @typedef {{
 *   candidate_list: !Array<{
 *     candidate: !Array,
 *   }>,
 * }}
 */
let GenerateCandidateResponse;

/**
 * Identifies the interactive translation session that was started.
 * @typedef {{
 *   session_id: number,
 * }}
 */
let StartTranslationResponse;
================================================
FILE: tensor2tensor/insights/polymer/explore_view/explore-view.html
================================================
================================================
FILE: tensor2tensor/insights/polymer/explore_view/explore-view.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * `<explore-view>` presents a view for debugging translations.
 *
 * This provides an interactive interface for querying a backend service to
 * fetch detailed analysis of a translation process. Each result will be
 * provided as a stack.
 *
 * ### Usage
 *
 *     <explore-view></explore-view>
 */
class ExploreView extends Polymer.Element {
  /**
   * @return {string} The component name.
   */
  static get is() {
    return 'explore-view';
  }

  /**
   * @return {!Object} The component properties.
   */
  static get properties() {
    return {
      route: {
        type: Object,
      },
      /**
       * Rapid response rules sent with every translation request. Defaults to
       * an empty list so `createBodyValue_` never reads an undefined array.
       * @type {!Array}
       */
      rules_: {
        type: Array,
        value: function() {
          return [];
        },
      },
      /**
       * @type {?Model}
       */
      model_: {
        type: Object
      },
      /**
       * @type {string}
       */
      query_: {
        type: String,
      },
      /**
       * Translation results, one entry per completed request. Defaults to an
       * empty list so `handleTranslationResponse_` can always push to it.
       * @type {!Array}
       */
      results: {
        type: Array,
        value: function() {
          return [];
        },
      },
    };
  }

  /**
   * @return {!Array} The component observers.
   */
  static get observers() {
    return [
      'modelChanged_(queryData, model_)',
    ];
  }

  /**
   * @override
   */
  ready() {
    super.ready();
    this.set('rules_', []);
    this.set('fetchingResult', false);
  }

  /**
   * Noop
   * @public
   */
  refresh() {
    // Noop
  }

  /**
   * Resets the results when a model changes and triggers a query automatically
   * if one exists.
   * @param {?{query: string}} queryData The current route data.
   * @param {?Model} model Unused, but needed for triggering.
   * @private
   */
  modelChanged_(queryData, model) {
    if (queryData && queryData.query) {
      // Compose the query from the querydata field and the path in the rest of
      // the route. If the link includes an escaped "/" app-route splits the
      // query and remaining path on that escaped "/". So query appears to not
      // include the rest of the intended query. A missing tail route or path
      // contributes nothing rather than throwing or appending "undefined".
      const tailRoute = this.get('tailRoute');
      const tailPath = (tailRoute && tailRoute.path) || '';
      const query = unescape(queryData.query) + tailPath;
      this.set('query_', query);
      this.translate_();
    }
    this.set('results', []);
    this.set('rules_', []);
  }

  /**
   * Sends a translation request to the server.
   * @private
   */
  translate_() {
    if (!this.model_ || !this.model_.id) {
      return;
    }
    const params = {
      'source': this.query_,
      'id': this.model_.id,
      'sl': this.model_.source_language.code,
      'tl': this.model_.target_language.code,
    };
    const paramList = this.createBodyValue_(params);
    this.set('url', '/debug?' + paramList);
    this.set('fetchingResult', true);
    this.$.translateAjax.generateRequest();
  }

  /**
   * Escapes a value for use inside a double-quoted string of the rule text.
   * Backslashes, double quotes and newlines are backslash-escaped so
   * user-entered text cannot end the quoted value early.
   * @param {*} value The raw value; it is converted to a string first.
   * @return {string} The escaped text, without surrounding quotes.
   * @private
   */
  escapeRuleString_(value) {
    return String(value)
        .replace(/\\/g, '\\\\')
        .replace(/"/g, '\\"')
        .replace(/\n/g, '\\n');
  }

  /**
   * Returns a string with all the query parameters composed together. This
   * also serializes the rapid response rules provided.
   * @param {!Object} params The params to combine.
   * @return {string} The params collapsed together.
   * @private
   */
  createBodyValue_(params) {
    const bodyParts = [];
    // Add the key value body parts.
    for (const param in params) {
      bodyParts.push(param + '=' + window.encodeURIComponent(params[param]));
    }
    // Add the rapid response rules. Each rule is written in a protobuf
    // text-format-like syntax (presumably parsed as such by the server --
    // confirm), so its string fields are quoted and escaped.
    const quoted = (value) => '"' + this.escapeRuleString_(value) + '"';
    for (const rule of this.rules_) {
      const value =
          'src_lang: ' + quoted(this.model_.source_language.code) + ' ' +
          'trg_lang: ' + quoted(this.model_.target_language.code) + ' ' +
          'source: ' + quoted(rule.source) + ' ' +
          'bad_translations: ' + quoted(rule.bad_translations) + ' ' +
          'good_translations: ' + quoted(rule.good_translations) + ' ' +
          'attention_threshold: ' + rule.attention_threshold;
      bodyParts.push('rule=' + window.encodeURIComponent(value));
    }
    // Combine everything together.
    return bodyParts.join('&');
  }

  /**
   * Adds the translation response to the list of results.
   * @param {!Event} event The event object from the `response` event. This is
   *   required to access the current response, as there are timing issues when
   *   accessing the latest response with iron-ajax's `last-response` attribute.
   * @private
   */
  handleTranslationResponse_(event) {
    this.set('fetchingResult', false);
    this.push('results', {
      response: event.detail.response,
      query: this.query_,
      model: this.model_,
    });
  }

  /**
   * Adds a new rapid response rule to be filled out.
   * @private
   */
  addRule_() {
    this.push('rules_', {
      source: '',
      bad_translations: '',
      good_translations: '',
      attention_threshold: 0.9,
    });
  }

  /**
   * Deletes a rapid response rule.
   * @param {Event} e The event in the dom repeat template element.
   * @private
   */
  deleteRule_(e) {
    this.splice('rules_', e.model.index, 1);
  }
}

customElements.define(ExploreView.is, ExploreView);
================================================
FILE: tensor2tensor/insights/polymer/graph_visualization/graph-visualization.html
================================================
Token: [[currentName]]
Token Probability: [[currentProbability]]
Total Probability: [[currentTotalProbability]]
Score: [[score]]
================================================
FILE: tensor2tensor/insights/polymer/graph_visualization/graph-visualization.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* `` Presents a beam search decoding graph.
*
* The Beam Search decoding graph visualizes the entire search space of a
* sequence generation model. Each layer in the graph displays a decoding step
* with nodes in that layer representing generated candidates. If supported by
* the backend server, the graph can enter interactive mode where candidates can
* be selected for each generation step.
*
*
* ### Usage
*
*
*/
class GraphVisualization extends Polymer.Element {
  constructor() {
    super();
    /**
     * D3 selection wrapping the <svg> element built by createSVG_.
     * @private
     */
    this.svg_ = undefined;
    /**
     * D3 selection of the top-level <g> that receives zoom/pan transforms.
     * @private
     */
    this.vis_ = undefined;
    /**
     * Counter used to hand out unique ids to interactive nodes. Initialized
     * before interactiveRoot_ below so the placeholder root gets id 0 rather
     * than undefined.
     * @private
     */
    this.nodeId_ = 0;
    /**
     * Root of the tree built from `data` for the non-interactive view.
     * @type {!TreeNode}
     * @private
     */
    this.rootTree_ = {
      name: '',
      localProbability: 0,
      cumalitiveProbability: 0,
      score: 0,
      attention: [],
      children: [],
    };
    /**
     * Root of the tree built during interactive step-by-step decoding.
     * @type {!InteractiveNode}
     * @private
     */
    this.interactiveRoot_ = {
      id: this.nodeId_,
      stepIndex: 0,
      candidate: {
        label: '',
        label_id: 1,
        log_probability: 0,
        total_log_probability: 0,
        score: 0,
        parent_id: 0
      },
      children: [],
    };
    /**
     * Interactive nodes whose candidates are sent with the next generate
     * request.
     * @type {Array}
     * @private
     */
    this.selectedNodes_ = [];
    /**
     * Interactive nodes created by the most recent generation step, indexed
     * by their stepIndex.
     * @private
     */
    this.stepNodes_ = [];
    /**
     * D3.js partition layout helper.
     * @private
     */
    this.partition_ = undefined;
    /**
     * D3.js zoom behavior.
     * @private
     */
    this.zoom_ = undefined;
    /**
     * D3.js <g> selection that nodes and links are drawn into.
     * @private
     */
    this.container_ = undefined;
  }

  /**
   * @return {string} The component name.
   */
  static get is() {
    return 'graph-visualization';
  }

  /**
   * @return {!Object} The component properties.
   */
  static get properties() {
    return {
      /**
       * @type {!SearchGraphVisualization}
       */
      data: {
        type: Object,
        observer: 'dataUpdated_',
      },
      /**
       * @type {!Model}
       */
      model: {
        type: Object,
      },
      /**
       * @type {string}
       */
      query: {
        type: String,
      },
      /**
       * @type {number}
       */
      zoomDepth_: {
        type: Number,
        value: 20,
      },
      /**
       * @type {!StartTranslationResponse}
       */
      startResponse_: {
        type: Object,
      },
      /**
       * @type {!GenerateCandidateResponse}
       */
      generateResponse_: {
        type: Object,
      },
    };
  }

  /**
   * @return {!Array} The component observers.
   */
  static get observers() {
    return [
      'zoomDepthChanged_(zoomDepth_)',
    ];
  }

  /**
   * Sets the default zoom depth and starts in view (non-interactive) mode.
   * @override
   */
  ready() {
    super.ready();
    this.set('zoomDepth_', 20);
    this.set('stepMode', 'view');
  }

  /**
   * Sets the zoom state based on the updated depth. A depth of 20 maps to a
   * scale of 1.
   * @param {number} zoomDepth the zoom depth.
   * @private
   */
  zoomDepthChanged_(zoomDepth) {
    if (!this.svg_) {
      return;
    }
    // Avoid a degenerate zero scale.
    if (zoomDepth == 0) {
      zoomDepth = 0.000001;
    }
    let transform = d3.zoomTransform(this.svg_.node()).scale(zoomDepth / 20.0);
    this.vis_.attr("transform", transform);
  }

  /**
   * Converts the NMT Graph JSON format to a nested tree hierarchy and plots
   * the tree as a collapsible tree visualization. When `data` is missing or
   * has no root node, an empty tree is plotted instead of throwing.
   * @private
   */
  dataUpdated_() {
    // The root is the node with no in links and some out links. Each tree
    // node is associated with the edge leading into it, so every graph node
    // except a shared terminal one is referenced only once as an edge head.
    // Terminal sinks (edges with an empty label) are not added as children;
    // see addChildToTree_.
    var nodes = (this.data && this.data.node) || [];

    // Step 1) Find the Root node index.
    var rootIndex = -1;
    for (var i = 0; i < nodes.length && rootIndex == -1; ++i) {
      var node = nodes[i];
      if (node.in_edge_index.length == 0 && node.out_edge_index.length != 0) {
        rootIndex = i;
      }
    }

    // Step 2) Create the root node in the tree. Tree nodes have:
    //   name: The display name of the node. This will be some token.
    //   localProbability: The per time step probability of this node.
    //   cumalitiveProbability: The total probability of this path in the beam
    //     search. (The misspelled key is read elsewhere; keep it.)
    //   score: A final score for this path in the beam search. This is
    //     typically the cumulative probability with zero or more penalties.
    //   attention: The attention vector associated with this node transition.
    //   children: The list of children in the tree, which are themselves
    //     trees.
    this.rootTree_ = {
      name: '',
      localProbability: 0,
      cumalitiveProbability: 0,
      score: 0,
      attention: [],
      children: [],
    };

    // Step 3) Add each child and its children recursively starting from the
    // root node. Without a root, the tree stays empty.
    if (rootIndex != -1) {
      var rootNode = nodes[rootIndex];
      var edges = this.data.edge;
      for (var j = 0; j < rootNode.out_edge_index.length; ++j) {
        var outEdge = edges[rootNode.out_edge_index[j]];
        this.addChildToTree_(this.rootTree_, outEdge, nodes, edges);
      }
    }
    this.propagateLabel_(this.rootTree_);
    this.createSVG_();
    this.plotTree_(this.rootTree_);
  }

  /**
   * Derives a node's pathType from its children's pathTypes, with priority
   * 'nbest' > 'beam' > 'alternative'. The result is 'unknown' when no child
   * has one of those types.
   * @param {!TreeNode} node The node to annotate.
   * @private
   */
  propagateLabel_(node) {
    var hasNBest = false;
    var hasBeam = false;
    var hasAlternative = false;
    for (var i = 0; i < node.children.length; ++i) {
      hasNBest = hasNBest || node.children[i].pathType == 'nbest';
      hasBeam = hasBeam || node.children[i].pathType == 'beam';
      hasAlternative = hasAlternative ||
          node.children[i].pathType == 'alternative';
    }
    if (hasNBest) {
      node.pathType = 'nbest';
    } else if (hasBeam) {
      node.pathType = 'beam';
    } else if (hasAlternative) {
      // NOTE(review): a parent of only 'alternative' children is labelled
      // 'beam', not 'alternative'. This looks deliberate (the path leading to
      // an alternative is drawn as beam) but should be confirmed.
      node.pathType = 'beam';
    } else {
      node.pathType = 'unknown';
    }
  }

  /**
   * Adds the target of `currentEdge` as a child of `tree`, then recursively
   * adds that child's own out edges.
   * @param {!TreeNode} tree The current node in the tree to update with
   *     children.
   * @param {!BeamSearchEdge} currentEdge The edge going into the new child.
   * @param {!Array} nodes The list of all node objects.
   * @param {!Array} edges The list of all edges between nodes.
   * @private
   */
  addChildToTree_(tree, currentEdge, nodes, edges) {
    var candidate = currentEdge.data;
    // An empty label marks a terminal sink. Don't add it; label the parent
    // instead.
    if (candidate.label == '') {
      tree.pathType = 'alternative';
      return;
    }
    var node = nodes[currentEdge.target_index];
    /**
     * @type {TreeNode}
     */
    var childTree = {
      name: candidate.label,
      attention: [],
      localProbability: Math.pow(Math.E, candidate.log_probability),
      cumalitiveProbability: Math.pow(Math.E, candidate.total_log_probability),
      score: Math.pow(Math.E, candidate.score),
      finished: currentEdge.completed || false,
      children: [],
      node: node,
      edge: currentEdge,
      pathType: 'unknown',
    };
    tree.children.push(childTree);
    if (node.out_edge_index.length == 0) {
      // Leaf. An empty name cannot reach this point (it returned above).
      if (childTree.name == ' ') {
        childTree.pathType = 'nbest';
      } else if (candidate.finished) {
        childTree.pathType = 'alternative';
      } else {
        childTree.pathType = 'beam';
      }
    } else {
      for (var i = 0; i < node.out_edge_index.length; ++i) {
        var outEdge = edges[node.out_edge_index[i]];
        this.addChildToTree_(childTree, outEdge, nodes, edges);
      }
      // Label once all children exist; the result only depends on the final
      // set of children.
      this.propagateLabel_(childTree);
    }
  }

  /**
   * Creates the initial SVG canvas and associated structures. This will
   * remove all previous svg elements.
   * @private
   */
  createSVG_() {
    // Create the margins, width, and height.
    var maxWidth = 1600;
    var maxHeight = 1600;
    var margins = [20, 120, 20, 20];
    var width = maxWidth - margins[1] - margins[3];
    var height = maxHeight - margins[0] - margins[2];
    // A d3 partition places each node according to its number of
    // descendants.
    this.partition_ = d3.partition().size([height, width]).padding(1);
    // Place the root at half the height, on the left.
    this.rootTree_.x0 = height / 2;
    this.rootTree_.y0 = 0;
    this.zoom_ = d3.zoom()
        .scaleExtent([1, 10])
        .on("zoom", zoomed.bind(this));
    d3.select(this.$.chart).selectAll('.svg-container').remove();
    // Embed the SVG that hosts the tree.
    this.svg_ = d3.select(this.$.chart)
        .append("div")
        .classed("svg-container", true)
        .append("svg")
        .attr("height", "100%")
        .attr("width", "100%")
        .classed("svg-content-responsive", true)
        .call(this.zoom_)
        .on('dblclick.zoom', null)
        .on('wheel.zoom', null);
    /**
     * Note: the javascript compiler can't figure out the type of zoomDepth_
     * here, so coerce it into a number.
     * @type {number}
     */
    let zoomDepth = parseInt(this.zoomDepth_, 10);
    let transform = d3.zoomTransform(this.svg_.node()).scale(zoomDepth / 20.0);
    this.vis_ = this.svg_.append('g')
        .attr("transform", transform);
    // Ensure that the entire svg element can be used for panning.
    this.vis_.append("rect")
        .attr("width", maxWidth)
        .attr("height", maxHeight)
        .style("fill", "none")
        .style("pointer-events", "all");
    this.container_ = this.vis_.append("g");
    // Apply the zoom transformation combined with the current zoom depth.
    function zoomed() {
      this.vis_.attr("transform",
                     d3.event.transform.scale(this.zoomDepth_ / 20.0));
    }
  }

  /**
   * Builds the SVG path for the curved link between a partition-laid-out
   * node and its parent. Links attach just below the circle's center so the
   * text label stays readable. Partition layouts only provide x0/y0, not
   * x/y, and the layout's x axis is drawn vertically.
   * @param {!Object} d A d3 hierarchy node positioned by d3.partition.
   * @return {string} The path description, or '' for the root.
   * @private
   */
  linkPath_(d) {
    if (!d.parent) {
      return "";
    }
    var nodeX = d.x0 + 16;
    var nodeY = d.y0 + 10;
    var parentX = d.parent.x0 + 16;
    var parentY = d.parent.y0 + 10;
    var midY = (nodeY + parentY) / 2;
    return "M" + nodeY + "," + nodeX +
        "C" + midY + "," + nodeX + " " +
        midY + "," + parentX + " " +
        parentY + "," + parentX;
  }

  /**
   * Examines and plots all reachable nodes in the rootTree.
   * @param {!TreeNode} root The current root node to focus on.
   * @private
   */
  plotTree_(root) {
    // NOTE(review): `root` is unused; the hierarchy is always built from
    // this.rootTree_ (which is what the only caller passes).
    // Node values count elements rather than weighting each node.
    var treeHierachy = d3.hierarchy(this.rootTree_)
        .sum(function(d) {
          return 1;
        })
        .sort(function(a, b) {
          // NOTE(review): this sorts ascending by score, while
          // drawInteractiveTree_ sorts descending by total log probability.
          // Confirm which order is intended.
          return a.data.score - b.data.score;
        });
    this.partition_(treeHierachy);
    // Create an enter selection where both nodes and links are added.
    var enter = this.container_.selectAll(".node")
        .data(treeHierachy.descendants())
        .enter();
    // Each node is a group holding an invisible rect, a circle and a label.
    var node = enter.append("g")
        .attr("class", function(d) {
          return "node" + (d.children ? " node--internal" : " node--leaf");
        })
        .attr("transform", function(d) {
          return "translate(" + d.y0 + "," + d.x0 + ")";
        })
        .attr('id', function(d, i) { return "g-" + i; });
    node.append("rect")
        .attr("width", function(d) { return d.y1 - d.y0; })
        .attr("height", 24);
    node.append("circle")
        .attr("r", 10)
        .attr("transform", "translate(10, 10)");
    node.append("text")
        .attr("x", 24)
        .attr("y", 13)
        .text(function(d) { return d.data.name; });
    // Link each node to its parent, colored by the node's path type.
    enter.append("path")
        .attr("class", "link")
        .attr("d", this.linkPath_)
        .style('stroke', function(d) {
          if (d.data.pathType == 'unknown')
            return '#222';
          if (d.data.pathType == 'nbest')
            return '#66ff33';
          if (d.data.pathType == 'beam')
            return '#ccc';
          if (d.data.pathType == 'alternative')
            return '#ff3300';
        });
    // On hover, fade all other nodes and highlight the hovered one.
    var allNodes = this.container_.selectAll(".node");
    node.on('mouseover', function(d, i, groupNodes) {
          allNodes.classed('fade', function(unused, j) {
            return i != j;
          });
          // This listener is bound to the element, so `this` is not the
          // hovered DOM node; use the group's node list that d3 passes in.
          d3.select(groupNodes[i]).classed('hover', true);
          this.set('currentName', d.data.name);
          this.set(
              'currentProbability', this.displayNumber(d.data.localProbability));
          this.set(
              'currentTotalProbability',
              this.displayNumber(d.data.cumalitiveProbability));
          this.set('score', this.displayNumber(d.data.score));
        }.bind(this))
        .on('mouseout', function(d, i) {
          allNodes.classed("fade", false);
          d3.select(this).classed("hover", false);
        });
  }

  /**
   * Resets the pan and zoom state.
   * @private
   */
  reset_() {
    if (!this.svg_) {
      return;
    }
    this.svg_.call(this.zoom_.transform, d3.zoomIdentity);
    this.set('zoomDepth_', 20);
  }

  /**
   * Formats a number with exactly two digits after the decimal point.
   * @param {number} value The value to present.
   * @return {string} value formatted with two decimal places.
   */
  displayNumber(value) {
    return value.toFixed(2);
  }

  /**
   * Enters step by step decoding mode.
   * @private
   */
  startStepMode_() {
    this.set('stepMode', 'edit');
    this.startTranslation_();
  }

  /**
   * Exits step by step decoding mode and redraws the static graph.
   * @private
   */
  exitStepMode_() {
    this.set('stepMode', 'view');
    this.dataUpdated_();
  }

  /**
   * Begins step by step decoding with the current model and query.
   * @private
   */
  startTranslation_() {
    this.set('startBody', JSON.stringify({
      model_id: {
        language_pair: {
          source_language: this.model.source_language.code,
          target_language: this.model.target_language.code,
        },
        name: this.model.id,
      },
      input: this.query,
    }));
    this.$.startAjax.generateRequest();
  }

  /**
   * Handles a start error.
   * @private
   */
  handleStartError_() {
    console.log("failed");
  }

  /**
   * Initializes the step by step decoding graph with the root node and makes
   * the first generation step.
   * @private
   */
  handleStartResponse_() {
    // Reset the node state and create the root of the tree. Later candidates
    // returned from the generation call are added beneath it.
    this.nodeId_ = 0;
    this.interactiveRoot_ = {
      id: this.nodeId_,
      stepIndex: 0,
      candidate: {
        label: '',
        label_id: 1,
        log_probability: 0,
        total_log_probability: 0,
        score: 0,
        parent_id: 0
      },
      children: [],
    };
    this.nodeId_++;
    // Drop step nodes from any previous session.
    this.stepNodes_ = [];
    // Nodes that are active inputs to the next generation step.
    this.selectedNodes_ = [this.interactiveRoot_];
    // Redraw the entire plot with an interactive version.
    this.createSVG_();
    this.drawInteractiveTree_(this.interactiveRoot_);
    // Make the first generation request.
    this.step_(true);
  }

  /**
   * Handles a generate ajax error.
   * @private
   */
  handleGenerateError_() {
    console.log("generate failed");
  }

  /**
   * Processes the returned candidates and adds them to the graph. Candidate
   * list i is attached to selectedNodes_[i]. At most five candidates are kept
   * per list, and the first of each list is selected by default.
   * @private
   */
  handleGenerateResponse_() {
    // Candidates are tagged with unique ids so later steps only accept
    // candidates from the most recent generation step. The remote decoder can
    // only continue from those.
    const maxCandidatesPerNode = 5;
    let stepIndex = 0;
    let newlySelectedNodes = [];
    this.stepNodes_ = [];
    let candidateLists = this.generateResponse_.candidate_list;
    for (var i = 0; i < candidateLists.length; ++i) {
      let selectedNode = this.selectedNodes_[i];
      if (!selectedNode) {
        // No parent node was sent for this list; skip it rather than throw
        // halfway through updating the graph state.
        console.log("Ignoring candidate list without a selected parent.");
        continue;
      }
      let candidateList = candidateLists[i];
      for (var j = 0;
           j < candidateList.candidate.length && j < maxCandidatesPerNode;
           ++j) {
        let candidate = candidateList.candidate[j];
        // Tag the parent id so the next generate call knows which network
        // states to maintain.
        candidate.parent_id = i;
        let newNode = {
          id: this.nodeId_,
          stepIndex: stepIndex,
          candidate: candidate,
          children: [],
        };
        this.nodeId_++;
        stepIndex++;
        this.stepNodes_.push(newNode);
        selectedNode.children.push(newNode);
        // Select the first candidate.
        if (j === 0) {
          newNode.selected = true;
          newlySelectedNodes.push(newNode);
        }
      }
    }
    this.selectedNodes_ = newlySelectedNodes;
    // Reset the graph.
    this.createSVG_();
    this.drawInteractiveTree_(this.interactiveRoot_);
  }

  /**
   * Draws the interactive tree. Nodes from the latest step can be clicked to
   * toggle their selection.
   * @param {InteractiveNode} rootNode The root node to draw out.
   * @private
   */
  drawInteractiveTree_(rootNode) {
    let treeHierachy = d3.hierarchy(rootNode)
        .sum(function(d) {
          return 1;
        })
        .sort(function(a, b) {
          return b.data.candidate.total_log_probability -
              a.data.candidate.total_log_probability;
        });
    this.partition_(treeHierachy);
    // Create an enter selection where both nodes and links are added.
    var enter = this.container_.selectAll(".node")
        .data(treeHierachy.descendants())
        .enter();
    // Each node is a group holding an invisible rect, a circle and a label.
    var node = enter.append("g")
        .attr("class", function(d) {
          return "node" +
              (d.children ? " node--internal" : " node--leaf") +
              (d.data.selected ? " selected" : "");
        })
        .attr("transform", function(d) {
          return "translate(" + d.y0 + "," + d.x0 + ")";
        })
        .attr('id', function(d, i) { return "g-" + i; });
    node.append("rect")
        .attr("width", function(d) { return d.y1 - d.y0; })
        .attr("height", 24);
    node.append("circle")
        .attr("r", 10)
        .attr("transform", "translate(10, 10)");
    node.append("text")
        .attr("x", 24)
        .attr("y", 13)
        .text(function(d) { return d.data.candidate.label; });
    // Link each node to its parent.
    enter.append("path")
        .attr("class", "link")
        .attr("d", this.linkPath_)
        .style('stroke', '#ccc');
    node.on('mouseover', function(d, i) {
      this.set('currentName', d.data.candidate.label);
      this.set(
          'currentProbability',
          this.displayNumber(Math.exp(d.data.candidate.log_probability)));
      this.set(
          'currentTotalProbability',
          this.displayNumber(Math.exp(d.data.candidate.total_log_probability)));
      this.set('score', this.displayNumber(Math.exp(d.data.candidate.score)));
    }.bind(this));
    // The click handler keeps d3's default `this` (the clicked DOM node), so
    // capture the step and selection lists locally.
    let stepNodes = this.stepNodes_;
    let selectedNodes = this.selectedNodes_;
    node.on('click', function(d, i) {
      // Ignore nodes that fall out of bounds.
      let stepIndex = d.data.stepIndex;
      if (stepIndex >= stepNodes.length) {
        return;
      }
      // Ignore nodes that are from different steps.
      let stepNode = stepNodes[stepIndex];
      if (stepNode.id != d.data.id) {
        return;
      }
      // Toggle the selection and keep the selected list in sync.
      if (!stepNode.selected) {
        stepNode.selected = true;
        selectedNodes.push(stepNode);
      } else {
        stepNode.selected = false;
        selectedNodes.splice(selectedNodes.indexOf(stepNode), 1);
      }
      d3.select(this).classed('selected', stepNode.selected);
    });
  }

  /**
   * Makes one generation step with the candidates in selectedNodes_. If no
   * nodes are selected, this does nothing except log.
   * @param {boolean=} opt_skipNext If true, skips the next step during
   *     generation.
   * @private
   */
  step_(opt_skipNext) {
    // Running generate without any nodes can put the decoder into a bad state
    // and make the session unusable, so for now skip this case.
    if (this.selectedNodes_.length == 0) {
      console.log("Skipping empty step.");
      return;
    }
    this.set('generateParams', {
      skip_next: opt_skipNext || false,
    });
    this.set('generateBody', JSON.stringify({
      model_id: {
        language_pair: {
          source_language: this.model.source_language.code,
          target_language: this.model.target_language.code,
        },
        name: this.model.id,
      },
      session_id: this.startResponse_.session_id,
      candidate: this.selectedNodes_.map(function(node) {
        return node.candidate;
      }),
    }));
    this.$.generateAjax.generateRequest();
  }
}
customElements.define(GraphVisualization.is, GraphVisualization);
================================================
FILE: tensor2tensor/insights/polymer/index.html
================================================
NMT Research Frontend
================================================
FILE: tensor2tensor/insights/polymer/insights_app/insights-app.html
================================================
Debug Frontend
Explore
================================================
FILE: tensor2tensor/insights/polymer/insights_app/insights-app.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* `<insights-app>` manages the views of the NMT Insights App.
*
* ### Usage
*
*
*
*/
class InsightsApp extends Polymer.Element {
  /**
   * @return {string} The component name.
   */
  static get is() {
    return 'insights-app';
  }

  /**
   * @return {!Object} The component properties.
   */
  static get properties() {
    return {
      /**
       * @type {string}
       */
      page: {
        type: String,
        reflectToAttribute: true,
      },
    };
  }

  /**
   * @return {!Array} The component observers.
   */
  static get observers() {
    return [
      'routePageChanged_(routeData.page)',
    ];
  }

  /**
   * Updates the page field from the route, falling back to 'explore' when
   * the route has no page, and refreshes the current page.
   * @param {?string} page The current page name being viewed.
   * @private
   */
  routePageChanged_(page) {
    // Only short-circuit for a real, unchanged page. Comparing an empty route
    // page (undefined/null) against a not-yet-set this.page would otherwise
    // count as "unchanged", and the default page would never be applied.
    if (page && page == this.page) {
      return;
    }
    this.page = page || 'explore';
    // This re-triggers the observer, which then returns early above.
    this.set('routeData.page', this.page);
    // Refresh the now selected page in case it needs new data on a new view.
    let currentPage = this.get('currentPage');
    if (currentPage) {
      currentPage.refresh();
    }
  }
}
customElements.define(InsightsApp.is, InsightsApp);
================================================
FILE: tensor2tensor/insights/polymer/language_selector/language-selector-content.html
================================================
[[item.name]] ([[item.code]])
================================================
FILE: tensor2tensor/insights/polymer/language_selector/language-selector-content.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* `<language-selector-content>` provides menu content for language selection.
*
* The content provides a search bar that will filter available languages by any
* language name or code that has the query text as a substring.
*
* By default, this will auto select a provided language with language code
* 'en'.
*
* ### Usage
*
*
*
*/
class LanguageSelectorContent extends Polymer.Element {
  /**
   * @return {string} The component name.
   */
  static get is() {
    return 'language-selector-content';
  }

  /**
   * @return {!Object} The component properties.
   */
  static get properties() {
    return {
      /**
       * @type {?Array}
       */
      languages: {
        type: Array,
        observer: 'languagesUpdated_',
      },
      /**
       * @type {!Language}
       */
      value: {
        type: Object,
        notify: true,
      },
      /**
       * @type {string}
       */
      defaultCode: {
        type: String,
        value: 'en',
      }
    };
  }

  /**
   * @return {!Array} The component observers.
   */
  static get observers() {
    return [
      'selectDefault_(languages, renderedItemCount)',
      'filterUpdated_(filter)',
    ];
  }

  /**
   * Clears the filter and selects the language whose code matches
   * `language.code`. Does nothing if no languages are loaded or none match.
   * @param {Language} language The language to pre-select.
   * @public
   */
  forceSelection(language) {
    this.set('filter', '');
    if (!this.languages) {
      return;
    }
    // Clearing the filter un-hides languages, but the list may not have
    // re-rendered yet. Flush it (as selectDefault_ does) before indexing the
    // rendered children by language index.
    this.$.languageList.render();
    for (var i = 0; i < this.languages.length; ++i) {
      if (this.languages[i].code == language.code) {
        this.set('value', this.languages[i]);
        this.updateSelected_(Polymer.dom(this.$.items).children[i]);
        return;
      }
    }
  }

  /**
   * Marks every new language visible and resets the filter and selection.
   * @param {?Array} newLanguages The new language list.
   * @private
   */
  languagesUpdated_(newLanguages) {
    if (newLanguages) {
      for (var i = 0; i < newLanguages.length; ++i) {
        newLanguages[i].hidden = false;
      }
    }
    this.set('filter', '');
    this.set('selected', undefined);
  }

  /**
   * Selects the current value's language, or otherwise the default language,
   * once all languages have been rendered in the menu. If only one language
   * exists, it is selected regardless of its code.
   * @param {?Array} languages The languages
   * @param {number} renderedItemCount The number of languages rendered.
   * @private
   */
  selectDefault_(languages, renderedItemCount) {
    if (this.get('selected') || !languages ||
        languages.length != renderedItemCount) {
      return;
    }
    this.$.languageList.render();
    if (this.value) {
      for (var i = 0; i < languages.length; ++i) {
        if (languages[i].code == this.value.code) {
          this.updateSelected_(Polymer.dom(this.$.items).children[i]);
          return;
        }
      }
    }
    let defaultCode = this.get('defaultCode');
    for (var j = 0; j < languages.length; ++j) {
      if (languages[j].code == defaultCode || languages.length == 1) {
        this.set('value', languages[j]);
        this.updateSelected_(Polymer.dom(this.$.items).children[j]);
        return;
      }
    }
  }

  /**
   * Selects the rendered language if only one is visible given the current
   * search filter.
   * @private
   */
  enterPressed_() {
    if (!this.languages) {
      return;
    }
    let visibleLanguagesIndices = [];
    for (var i = 0; i < this.languages.length; ++i) {
      if (!this.languages[i].hidden) {
        visibleLanguagesIndices.push(i);
      }
    }
    if (visibleLanguagesIndices.length == 1) {
      this.set('value', this.languages[visibleLanguagesIndices[0]]);
      // NOTE(review): assumes the list only renders visible languages (e.g.
      // a dom-repeat filtered by isShown_), so the sole match is child 0.
      this.updateSelected_(Polymer.dom(this.$.items).children[0]);
    }
  }

  /**
   * Sets the hidden state of languages given the current filter. The match
   * is case-insensitive.
   * @param {string} newFilter The new filter to match languages against.
   * @private
   */
  filterUpdated_(newFilter) {
    if (!this.get('languages')) {
      return;
    }
    let filter = (newFilter || '').toLowerCase();
    for (var i = 0; i < this.languages.length; ++i) {
      let hidden = !this.languageMatchesQuery_(this.languages[i], filter);
      this.set('languages.' + i + '.hidden', hidden);
    }
  }

  /**
   * Returns true if the language is visible.
   * @param {!Language} language The language being evaluated.
   * @return {boolean} True if visible.
   * @private
   */
  isShown_(language) {
    return !language.hidden;
  }

  /**
   * Returns true if the filter is empty or is a substring of the language's
   * lowercased name or code.
   * @param {!Language} language The language being evaluated.
   * @param {string} filter The lowercased filter to compare against.
   * @return {boolean} True if language matches filter.
   * @private
   */
  languageMatchesQuery_(language, filter) {
    let languageName = language.name.toLowerCase();
    // Lowercase the code too; the filter is lowercased by the caller.
    let languageCode = language.code.toLowerCase();
    return filter == '' || languageName.indexOf(filter) >= 0 ||
        languageCode.indexOf(filter) >= 0;
  }

  /**
   * Selects the tapped element and updates the value with the corresponding
   * language value.
   * @param {!Event} e The tap event.
   * @private
   */
  select_(e) {
    let language = this.$.languageList.itemForElement(e.target);
    this.set('value', language);
    this.updateSelected_(e.target);
  }

  /**
   * Updates the selection with the given element. Fires `iron-deselect` for
   * any previous selection, then `iron-select` for the new one.
   * @param {!Element} ele The selected dom element.
   * @private
   */
  updateSelected_(ele) {
    let oldSelection = this.get('selected');
    if (oldSelection) {
      this.dispatchEvent(new CustomEvent('iron-deselect', {
        bubbles: true,
        composed: true,
        detail: {
          item: oldSelection,
        },
      }));
    }
    this.set('selected', ele);
    this.dispatchEvent(new CustomEvent('iron-select', {
      bubbles: true,
      composed: true,
      detail: {
        item: ele,
      },
    }));
  }
}
customElements.define(LanguageSelectorContent.is, LanguageSelectorContent);
================================================
FILE: tensor2tensor/insights/polymer/language_selector/language-selector.html
================================================
================================================
FILE: tensor2tensor/insights/polymer/language_selector/language-selector.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* `<language-selector>` provides a searchable dropdown of languages.
*
* The dropdown will present the selected language's Name. When opened, the
* search bar will filter available languages by any language name or code that
* has the query text as a substring.
*
* By default, this will auto select a provided language with language code
* 'en'.
*
* ### Usage
*
*
*
*/
class LanguageSelector extends Polymer.Element {
  /**
   * Tag name this element is registered under.
   * @return {string} The component name.
   */
  static get is() {
    const tagName = 'language-selector';
    return tagName;
  }

  /**
   * Polymer property declarations for the selector.
   * @return {!Object} The component properties.
   */
  static get properties() {
    const props = {};
    /** @type {string} */
    props.label = {type: String};
    /** @type {?Array} */
    props.languages = {type: Array};
    /** @type {!Language} */
    props.value = {type: Object, notify: true};
    /** @type {string} */
    props.defaultCode = {type: String, value: 'en'};
    return props;
  }

  /**
   * Pre-selects `language` by delegating to the inner `#selector` element's
   * forceSelection method.
   * @param {Language} language The language to pre-select.
   * @public
   */
  forceSelection(language) {
    const innerSelector = this.$.selector;
    innerSelector.forceSelection(language);
  }
}
customElements.define(LanguageSelector.is, LanguageSelector);
================================================
FILE: tensor2tensor/insights/polymer/processing_visualization/processing-visualization.html
================================================
Source Processing
Target Processing
================================================
FILE: tensor2tensor/insights/polymer/processing_visualization/processing-visualization.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* `<processing-visualization>` summarises pre/post processing steps.
*
* This element presents the pre-processing segmentation steps and
* post-processing de-segmentation and rewrite steps that are applied to a
* translation query.
*
* ### Usage
*
*
*/
class ProcessingVisualization extends Polymer.Element {
  /**
   * Tag name this element is registered under.
   * @return {string} The component name.
   */
  static get is() {
    const tagName = 'processing-visualization';
    return tagName;
  }

  /**
   * Polymer property declarations. The element exposes a single `data`
   * object holding the processing steps to display.
   * @return {!Object} The component properties.
   */
  static get properties() {
    /** @type {!QueryProcessingVisualization} */
    const dataProperty = {type: Object};
    return {data: dataProperty};
  }
}
customElements.define(ProcessingVisualization.is, ProcessingVisualization);
================================================
FILE: tensor2tensor/insights/polymer/query_card/query-card.html
================================================
================================================
FILE: tensor2tensor/insights/polymer/query_card/query-card.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* `<query-card>` presents a material card for selecting a supported model.
*
* This will fetch a set of supported models for debugging and provide three
* selectors:
* - Source Language
* - Target Language
* - Model
* Once all three have been populated, it will emit a `Model` object through
* `model`.
*
* ### Usage
*
*
* Custom InputField
*
*/
class QueryCard extends Polymer.Element {
  constructor() {
    super();
    /**
     * A general mapping from language code to the language objects.
     * @type {!Object<string, !Language>}
     * @private
     */
    this.languageToNameMap_ = {};
    /**
     * A nested mapping from source language code, to target language code, to
     * model id, to the model itself.
     * @type {!Object<string, !Object<string, !Object<string, !Model>>>}
     * @private
     */
    this.languagePairToModelMap_ = {};
  }

  /**
   * @return {string} The component name.
   */
  static get is() {
    return 'query-card';
  }

  /**
   * @return {!Object} The component properties.
   */
  static get properties() {
    return {
      /**
       * @type {!Object}
       */
      route: {
        type: String,
      },
      /**
       * @type {!Object}
       */
      subRoute: {
        type: String,
        notify: true,
      },
      /**
       * @type {?Model}
       */
      model: {
        type: Object,
        notify: true,
      },
      /**
       * @type {string}
       */
      url: {
        type: String,
      },
      /**
       * @type {?Language}
       */
      sourceLanguage_: {
        type: Object,
      },
      /**
       * @type {?Language}
       */
      targetLanguage_: {
        type: Object,
      },
      /**
       * Id of the model preferred when the route does not name one.
       * @type {string}
       */
      defaultModelId: {
        type: String,
        value: 'prod',
      },
    };
  }

  /**
   * @return {!Array<string>} The component observers.
   */
  static get observers() {
    return [
      'routeActiveUpdated_(routeActive)',
      'modelsUpdated_(modelConfigurations)',
      'sourceLanguagesUpdated_(sourceLanguages, routeData)',
      'targetLanguagesUpdated_(targetLanguages, routeData)',
      'sourceLanguageUpdated_(sourceLanguage_)',
      'targetLanguageUpdated_(targetLanguage_)',
      'modelListUpdated_(modelList, routeData)',
      'modelUpdated_(model)',
    ];
  }

  /**
   * Resets the route data if the route is inactive.
   * @param {boolean} routeActive The active state of the route.
   * @private
   */
  routeActiveUpdated_(routeActive) {
    if (!routeActive) {
      this.set('routeData', {});
    }
  }

  /**
   * Selects the source language named by the route path, if it is among the
   * available source languages.
   * @param {?Array<!Language>} sourceLanguages A list of source languages.
   * @param {?{sourceLanguage: string}} routeData The current route paths.
   * @private
   */
  sourceLanguagesUpdated_(sourceLanguages, routeData) {
    // routeData may not be populated yet when this observer first fires.
    if (!this.routeActive || !sourceLanguages || !routeData) {
      return;
    }
    for (const language of sourceLanguages) {
      if (routeData.sourceLanguage == language.code) {
        this.$.sourceSelector.forceSelection(language);
        return;
      }
    }
  }

  /**
   * Rebuilds the available target language list for the newly selected
   * source language and clears the current target selection.
   * @param {?Language} sourceLanguage The selected source language.
   * @private
   */
  sourceLanguageUpdated_(sourceLanguage) {
    if (sourceLanguage == undefined) {
      this.set('targetLanguages', []);
      return;
    }
    this.set('routeData.sourceLanguage', sourceLanguage.code);
    const targetCodes = Object.keys(
        this.languagePairToModelMap_[sourceLanguage.code] || {});
    const targetLanguages =
        targetCodes.map((code) => this.languageToNameMap_[code]);
    targetLanguages.sort(sort_);
    // Clear the stale target selection. This must be the declared
    // `targetLanguage_` property; resetting an undeclared `targetLanguage`
    // left the previous pair's target (and its model list) selected.
    this.set('targetLanguage_', undefined);
    this.set('targetLanguages', targetLanguages);
  }

  /**
   * Selects the target language named by the route path, if it is among the
   * available target languages.
   * @param {?Array<!Language>} targetLanguages A list of target languages.
   * @param {?{targetLanguage: string}} routeData The current route paths.
   * @private
   */
  targetLanguagesUpdated_(targetLanguages, routeData) {
    // routeData may not be populated yet when this observer first fires.
    if (!this.routeActive || !targetLanguages || !routeData) {
      return;
    }
    for (const language of targetLanguages) {
      if (routeData.targetLanguage == language.code) {
        this.$.targetSelector.forceSelection(language);
        return;
      }
    }
  }

  /**
   * Rebuilds the available model list for the selected language pair.
   * @param {?Language} targetLanguage The selected target language.
   * @private
   */
  targetLanguageUpdated_(targetLanguage) {
    this.set('model', undefined);
    const sourceLanguage = this.sourceLanguage_;
    // Without both halves of the pair there are no models to offer.
    if (targetLanguage == undefined || sourceLanguage == undefined) {
      this.set('modelList', []);
      return;
    }
    this.set('routeData.targetLanguage', targetLanguage.code);
    const targetLanguageMap =
        this.languagePairToModelMap_[sourceLanguage.code] || {};
    const modelMap = targetLanguageMap[targetLanguage.code] || {};
    this.set('modelList', Object.keys(modelMap).map((id) => modelMap[id]));
  }

  /**
   * Selects a model from the new model list. The preference order is: the
   * model named by the route path, then `defaultModelId`, then the first
   * entry.
   * @param {?Array<!Model>} modelList A list of models.
   * @param {?{modelId: string}} routeData The current route paths.
   * @private
   */
  modelListUpdated_(modelList, routeData) {
    if (!modelList || modelList.length == 0) {
      return;
    }
    const findById = (id) => modelList.find((model) => model.id == id);
    if (this.routeActive && routeData) {
      const routedModel = findById(routeData.modelId);
      if (routedModel) {
        this.set('model', routedModel);
        return;
      }
    }
    // Choose the default model if it exists, otherwise the first entry, so
    // that the ordering of models doesn't affect the default selection.
    this.set('model', findById(this.defaultModelId) || modelList[0]);
  }

  /**
   * Mirrors the selected model's id into the route path.
   * @param {?Model} model The currently selected model.
   * @private
   */
  modelUpdated_(model) {
    if (!model) {
      return;
    }
    this.set('routeData.modelId', model.id);
  }

  /**
   * Rebuilds the language and model lookup tables from the server's model
   * configurations and publishes the sorted list of source languages.
   * @param {?{configuration: !Array<!Model>}} modelConfigurations A list of
   *     models.
   * @private
   */
  modelsUpdated_(modelConfigurations) {
    if (!modelConfigurations || !modelConfigurations.configuration) {
      return;
    }
    this.languageToNameMap_ = {};
    this.languagePairToModelMap_ = {};
    for (const model of modelConfigurations.configuration) {
      // Extract the language codes and store the code to language mappings.
      const sourceCode = model.source_language.code;
      const targetCode = model.target_language.code;
      this.languageToNameMap_[sourceCode] = model.source_language;
      this.languageToNameMap_[targetCode] = model.target_language;

      // First level: source language code -> map of target languages.
      if (!(sourceCode in this.languagePairToModelMap_)) {
        this.languagePairToModelMap_[sourceCode] = {};
      }
      const targetLanguageMap = this.languagePairToModelMap_[sourceCode];

      // Second level: target language code -> map of model id to model.
      if (!(targetCode in targetLanguageMap)) {
        targetLanguageMap[targetCode] = {};
      }
      targetLanguageMap[targetCode][model.id] = model;
    }
    // Prepare the initial set of available source languages.
    const sourceLanguageList = Object.keys(this.languagePairToModelMap_)
        .map((code) => this.languageToNameMap_[code]);
    sourceLanguageList.sort(sort_);
    this.set('sourceLanguages', sourceLanguageList);
  }
}
customElements.define(QueryCard.is, QueryCard);
/**
 * Comparator that orders two languages by their display name.
 * @param {!Language} a The first language to compare.
 * @param {!Language} b The second language to compare.
 * @return {number} -1 if a sorts first, 1 if b sorts first, 0 if the names
 *     are (loosely) equal.
 */
function sort_(a, b) {
  if (a.name == b.name) {
    return 0;
  }
  return a.name < b.name ? -1 : 1;
}
================================================
FILE: tensor2tensor/insights/polymer/tensor2tensor.html
================================================
NMT Research Frontend
================================================
FILE: tensor2tensor/insights/polymer/translation_result/translation-result.html
================================================
================================================
FILE: tensor2tensor/insights/polymer/translation_result/translation-result.js
================================================
/**
* @license
* Copyright 2018 The Tensor2Tensor Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* `<translation-result>` presents zero or more visualizations of a translation.
*
* This inspects the set of visualization fields provided and triggers the
* corresponding visualization component in the set of available views in a
* tabbed layout.
*
* ### Usage
*
*
*
*/
class TranslationResult extends Polymer.Element {
  /**
   * @return {string} The component name.
   */
  static get is() {
    return 'translation-result';
  }

  /**
   * @return {!Object} The component properties.
   */
  static get properties() {
    return {
      /**
       * @type {{
       *   response: {
       *     result: !Array<{
       *       visualization_name: string,
       *       title: string,
       *       name: string,
       *       query_processing: ?Object,
       *       search_graph: ?Object,
       *       word_heat_map: ?Object,
       *     }>,
       *   },
       *   model: !Model,
       *   query: string
       * }}
       */
      result: {
        type: Object,
        observer: 'resultUpdated_',
      },
      /**
       * Name of the visualization view currently selected.
       * @type {string}
       */
      view: {
        type: String,
        value: 'processing',
      },
    };
  }

  /**
   * Creates and attaches one visualization element per entry of the updated
   * result, then selects the first entry's view.
   * @param {?Object} result The new value of the `result` property.
   * @private
   */
  resultUpdated_(result) {
    const response = result ? result.response : undefined;
    if (!response || !response.result || response.result.length == 0) {
      return;
    }
    // Maps a visualization_name to the response field holding its data.
    // TODO(kstevens): Cleanup by setting visualization_name the same as the
    // protobuffer field names so we don't need this mapping.
    const dataFieldByName = {
      'processing': 'query_processing',
      'attention': 'word_heat_map',
      'graph': 'search_graph',
    };
    for (const visualizationResult of response.result) {
      const visualizationName = visualizationResult.visualization_name;
      // Dynamically create the visualization element based on the name field.
      // This will enable multiple versions of the same visualization to be
      // created later on when the data mapping is generalized.
      const analysisEle =
          document.createElement(visualizationName + '-visualization');
      // Set the generic attributes.
      analysisEle.name = visualizationResult.name;
      analysisEle.model = result.model;
      analysisEle.query = result.query;
      // Set the visualization specific data attribute, if one is known.
      if (Object.prototype.hasOwnProperty.call(
              dataFieldByName, visualizationName)) {
        analysisEle.data =
            visualizationResult[dataFieldByName[visualizationName]];
      }
      Polymer.dom(this.$.view).appendChild(analysisEle);
    }
    // Don't make assumptions about which visualizations are available. Instead
    // preselect the initial view based on data.
    this.set('view', response.result[0].name);
  }
}
customElements.define(TranslationResult.is, TranslationResult);
================================================
FILE: tensor2tensor/insights/query_processor.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A base class for all query processing classes."""
class QueryProcessor(object):
  """Interface for objects that turn a text query into visualizations.

  Subclasses convert a string query into a series of visualization
  structures; this base implementation produces none.

  TODO(kstevens): Define how the visualization structures should look once the
  protos are in better shape.
  """

  def process(self, query):
    """Generates the visualizations for a query.

    Args:
      query: The string input.

    Returns:
      A dict with the single key 'result', mapped to a list of visualization
      objects (always empty for this base class).
    """
    del query  # The base class never inspects the query.
    visualizations = []
    return {"result": visualizations}
================================================
FILE: tensor2tensor/insights/server.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A GUnicorn + Flask Debug Frontend for Transformer models."""
import json
from flask import Flask
from flask import jsonify
from flask import request
from flask import send_from_directory
from flask.json import JSONEncoder
from gunicorn.app.base import BaseApplication
from gunicorn.six import iteritems
import numpy as np
from tensor2tensor.insights import transformer_model
import tensorflow.compat.v1 as tf
# Command-line flags for this server.
flags = tf.flags
FLAGS = flags.FLAGS
# NOTE(review): main() passes this value to open(), so in practice it is the
# path of a JSON file rather than the JSON text itself.
flags.DEFINE_string("configuration", "",
                    "A JSON InsightConfiguration message that configures which "
                    "models to run in the insight frontend.")
# Directory used as the Flask static folder and for send_from_directory().
flags.DEFINE_string("static_path", "",
                    "Path to static javascript and html files to serve.")
# NumPy scalar types that should be serialized as Python ints.
_NUMPY_INT_DTYPES = [
    np.int8, np.int16, np.int32, np.int64
]
# NumPy scalar types that should be serialized as Python floats.
_NUMPY_FP_DTYPES = [
    np.float16, np.float32, np.float64
]
class NumpySerializationFix(JSONEncoder):
  """JSON encoder that converts NumPy values into native Python types.

  The json module cannot serialize NumPy scalars or arrays, so they are
  reinterpreted before encoding. Every other object is delegated to the
  parent encoder.
  """

  def default(self, obj):
    """Returns a JSON-serializable stand-in for obj.

    Args:
      obj: An object the base encoder could not serialize natively.

    Returns:
      A Python int, float, bool or (nested) list equivalent of obj.

    Raises:
      TypeError: If obj is not a NumPy value and the parent encoder cannot
        serialize it either.
    """
    # isinstance on the abstract NumPy bases covers every width, including
    # the unsigned integers that an exact-type list would miss.
    if isinstance(obj, np.integer):
      return int(obj)
    if isinstance(obj, np.floating):
      return float(obj)
    if isinstance(obj, np.bool_):
      return bool(obj)
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    # Defer to the actual parent class rather than json.JSONEncoder, so any
    # extra types the parent encoder understands keep working.
    return super(NumpySerializationFix, self).default(obj)
class DebugFrontendApplication(BaseApplication):
  """A local custom application for GUnicorns.

  This custom application enables us to run with a custom main that parses
  tensorflow ops and does some internal setup prior to processing queries. The
  underlying app registered instances of this class will be forked.
  """

  def __init__(self, app, options=None):
    """Creates the GUnicorn application.

    Args:
      app: A Flask application that will process requests.
      options: A dict of GUnicorn options; None means no overrides.
    """
    # Assigned before the base constructor runs, because gunicorn's
    # BaseApplication constructor loads the configuration (load_config).
    self.options = options or {}
    self.application = app
    super(DebugFrontendApplication, self).__init__()

  def load_config(self):
    """Copies recognised, non-None options into the gunicorn config."""
    # Plain dict.items() avoids depending on gunicorn.six, which newer
    # gunicorn releases no longer ship.
    config = {key: value for key, value in self.options.items()
              if key in self.cfg.settings and value is not None}
    for key, value in config.items():
      self.cfg.set(key.lower(), value)

  def load(self):
    """Loads the application.

    Returns:
      The Flask application.
    """
    return self.application
def main(_):
  """Loads the configured models and serves the debug frontend.

  Args:
    _: Unused argv passed by tf.app.run.
  """
  # Create the models we support:
  with open(FLAGS.configuration) as configuration_file:
    configuration = json.load(configuration_file)

  # Read in the set of query processors, keyed by
  # (source_language, target_language, label).
  processors = {}
  for processor_configuration in configuration["configuration"]:
    key = (processor_configuration["source_language"],
           processor_configuration["target_language"],
           processor_configuration["label"])
    processors[key] = transformer_model.TransformerModel(
        processor_configuration)

  # Read in the list of supported languages.
  languages = {}
  for language in configuration["language"]:
    languages[language["code"]] = {
        "code": language["code"],
        "name": language["name"],
    }

  # Create flask to serve all paths starting with '/polymer' from the static
  # path. This is to serve non-vulcanized components.
  app = Flask(
      __name__.split(".")[0],
      static_url_path="/polymer",
      static_folder=FLAGS.static_path)
  app.json_encoder = NumpySerializationFix

  # Disable static file caching.
  app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0

  @app.route("/api/language_list/")
  def language_list():  # pylint: disable=unused-variable
    """Responds to /api/language_list with the supported languages.

    Returns:
      JSON for the languages.
    """
    return jsonify({
        "language": list(languages.values())
    })

  @app.route("/api/list_models/")
  def list_models():  # pylint: disable=unused-variable
    """Responds to /api/list_models with the supported models.

    Returns:
      JSON for the supported models.
    """
    # pylint: disable=g-complex-comprehension
    configuration_list = [{
        "id": label,
        "source_language": languages[source_code],
        "target_language": languages[target_code],
    } for source_code, target_code, label in processors]
    return jsonify({
        "configuration": configuration_list
    })

  @app.route("/debug", methods=["GET"])
  def query():  # pylint: disable=unused-variable
    """Responds to /debug with processing results.

    Returns:
      JSON for the query's result. Responds 400 when the 'source' parameter
      is missing and 404 when no model matches (sl, tl, id).
    """
    source_query = request.args.get("source")
    source_language = request.args.get("sl")
    target_language = request.args.get("tl")
    model_name = request.args.get("id")
    if source_query is None:
      return jsonify({"error": "Missing required 'source' parameter."}), 400
    processor = processors.get((source_language, target_language, model_name))
    if processor is None:
      return jsonify({
          "error": "No model for sl=%s, tl=%s, id=%s." % (
              source_language, target_language, model_name)
      }), 404
    return jsonify(processor.process(source_query))

  # Catchall for all other paths. Any other path should get the basic index
  # page, the polymer side will determine what view to show and what REST calls
  # to make for data. The second rule needs the <path:path> variable;
  # registering "/" twice would leave every other path unrouted.
  @app.route("/", defaults={"path": ""})
  @app.route("/<path:path>")
  def root(path):  # pylint: disable=unused-variable
    """Responds to all other non-static paths with index.html.

    Args:
      path: The requested path; only used to detect the bundled js files.

    Returns:
      The landing page html text.
    """
    if (path == "index.js" or
        path == "webcomponentsjs/webcomponents-lite.js"):
      # Some vulcanizing methods bundle the javascript into a index.js file
      # paired with index.html but leave two important webcomponents js files
      # outside of the bundle. If requesting those special files, fetch them
      # directly rather than from a /static sub-directory.
      return send_from_directory(FLAGS.static_path, path)
    # Everything else should redirect to the main landing page. Since we
    # use a single page app, any initial url requests may include random
    # paths (that don't start with /api or /static) which all should be
    # served by the main landing page.
    return send_from_directory(FLAGS.static_path, "index.html")

  # Run the server.
  tf.logging.info("############# READY ##################")
  options = {
      "bind": ":8010",
      "timeout": 600,
      "workers": 4,
      "reload": True,
      "spew": True,
      "worker_class": "gevent",
  }
  DebugFrontendApplication(app, options).run()


if __name__ == "__main__":
  tf.app.run()
================================================
FILE: tensor2tensor/insights/transformer_model.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A QueryProcessor using the Transformer framework."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import deque
import glob
import os
import shutil
import tempfile
import time
import numpy as np
from six.moves import range
from tensor2tensor.bin import t2t_trainer
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.insights import graph
from tensor2tensor.insights import query_processor
from tensor2tensor.utils import decoding
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import usr_dir
import tensorflow.compat.v1 as tf
from tensorflow.python import debug as tfdbg
# TensorFlow flag registry. TransformerModel.__init__ writes FLAGS.output_dir
# and reads FLAGS.t2t_usr_dir.
flags = tf.flags
FLAGS = flags.FLAGS
def topk_watch_fn(feeds, fetches):
  """Tells tfdbg to watch only the beam search top-k nodes.

  Args:
    feeds: Ignored; part of the tfdbg watch_fn signature.
    fetches: Ignored; part of the tfdbg watch_fn signature.

  Returns:
    A tfdbg.WatchOptions that attaches a DebugIdentity op to every node whose
    name matches grow_(finished|alive)_(topk_scores|topk_seq).
  """
  del feeds, fetches
  beam_node_pattern = ".*grow_(finished|alive)_(topk_scores|topk_seq).*"
  # NOTE(review): newer TF releases renamed this keyword to
  # node_name_regex_allowlist; confirm against the pinned TF version.
  return tfdbg.WatchOptions(
      debug_ops=["DebugIdentity"],
      node_name_regex_whitelist=beam_node_pattern)
def seq_filter(datum, tensor):
  """Selects tfdbg dump entries produced by a topk_seq node.

  Args:
    datum: A tfdbg dump datum; only its node_name is inspected.
    tensor: Ignored; part of the tfdbg predicate signature.

  Returns:
    True when the datum's node name contains "topk_seq".
  """
  del tensor
  node_name = datum.node_name
  return "topk_seq" in node_name
def scores_filter(datum, tensor):
  """Selects tfdbg dump entries produced by a topk_scores node.

  Args:
    datum: A tfdbg dump datum; only its node_name is inspected.
    tensor: Ignored; part of the tfdbg predicate signature.

  Returns:
    True when the datum's node name contains "topk_scores".
  """
  del tensor
  node_name = datum.node_name
  return "topk_scores" in node_name
def sequence_key(sequence):
  """Builds the graph-vertex key for a decoded prefix.

  Args:
    sequence: An iterable of token ids.

  Returns:
    The ids rendered with str() and joined by ":", e.g. [0, 5] -> "0:5".
  """
  return ":".join(map(str, sequence))
class TransformerModel(query_processor.QueryProcessor):
  """A QueryProcessor using a trained Transformer model.

  This processor supports the following visualizations:
    - processing: Basic source and target text processing
    - graph: A graph of the beam search process.
  """

  # Root directory under which per-request TFDBG dumps are written.
  _DUMP_ROOT = "/tmp/t2t_server_dump"

  def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A dict parsed from the processor configuration
        JSON. Its "transformer" entry must provide "model_dir", "data_dir",
        "hparams_set", "hparams", "problem" and "model".
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    transformer_config = processor_configuration["transformer"]
    FLAGS.output_dir = transformer_config["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(transformer_config["data_dir"])

    # Create the basic hyper parameters.
    self.hparams = trainer_lib.create_hparams(
        transformer_config["hparams_set"],
        transformer_config["hparams"],
        data_dir=data_dir,
        problem_name=transformer_config["problem"])

    decode_hp = decoding.decode_hparams()
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = trainer_lib.create_estimator(
        transformer_config["model"],
        self.hparams,
        t2t_trainer.create_run_config(self.hparams),
        decode_hparams=decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
    self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
    self.const_array_size = 10000

    # Remove debug dumps left behind by earlier runs. process() writes its
    # dumps into request_* directories, so those are cleaned as well as run_*.
    stale_dirs = (glob.glob(os.path.join(self._DUMP_ROOT, "run_*")) +
                  glob.glob(os.path.join(self._DUMP_ROOT, "request_*")))
    for stale_dir in sorted(stale_dirs):
      shutil.rmtree(stale_dir, ignore_errors=True)

  def process(self, query):
    """Returns the visualizations for query.

    Args:
      query: The query to process.

    Returns:
      A dictionary of results with processing and graph visualizations.
    """
    tf.logging.info("Processing new query [%s]", query)

    # Give every request its own TFDBG dump directory. mkdtemp guarantees a
    # unique name; a second-resolution timestamp collides when two requests
    # arrive within the same second and makes directory creation fail.
    tf.gfile.MakeDirs(self._DUMP_ROOT)
    hook_dir = tempfile.mkdtemp(prefix="request_", dir=self._DUMP_ROOT)
    try:
      hooks = [tfdbg.DumpingDebugHook(hook_dir, watch_fn=topk_watch_fn)]
      result = self._predict(query, hooks)
      decoding_graph = self._build_decoding_graph(hook_dir)
    finally:
      # Delete the hook dir to save disk space, even when decoding failed.
      shutil.rmtree(hook_dir, ignore_errors=True)

    # Create the graph visualization data structure.
    graph_vis = {
        "visualization_name": "graph",
        "title": "Graph",
        "name": "graph",
        "search_graph": decoding_graph.to_dict(),
    }
    processing_vis = self._processing_visualization(query, result)
    return {
        "result": [processing_vis, graph_vis],
    }

  def _predict(self, query, hooks):
    """Runs the estimator on query and returns its first prediction dict."""
    # TODO(kstevens): This is extremely hacky and slow for responding to
    # queries. Figure out a reasonable way to pre-load the model weights before
    # forking and run queries through the estimator quickly.
    def server_input_fn():
      """Generator that yields the feature dict for the current query once."""
      input_ids = self.source_vocab.encode(query)
      input_ids.append(text_encoder.EOS_ID)
      # NOTE(review): the 3-element header is presumably
      # [num_samples, decode_length, input_length], as read by
      # decoding._interactive_input_tensor_to_features_dict -- confirm.
      x = [1, 100, len(input_ids)] + input_ids
      x += [0] * (self.const_array_size - len(x))
      yield {
          "inputs": np.array(x).astype(np.int32),
      }

    def input_fn():
      """Builds the features dict for the current query."""
      gen_fn = decoding.make_input_fn_from_generator(server_input_fn())
      example = gen_fn()
      # TODO(kstevens): Make this method public
      # pylint: disable=protected-access
      return decoding._interactive_input_tensor_to_features_dict(
          example, self.hparams)

    # Only the first prediction is needed.
    result_iter = self.estimator.predict(input_fn, hooks=hooks)
    return next(iter(result_iter), None)

  def _build_decoding_graph(self, hook_dir):
    """Builds the beam search graph from the TFDBG dumps under hook_dir."""
    # Extract the beam search information by reading the dumped TFDBG event
    # tensors. We first read and record the per step beam sequences then record
    # the beam scores. Afterwards we align the two sets of values to create the
    # full graph vertices and edges.
    decoding_graph = graph.Graph()
    for run_dir in sorted(glob.glob(os.path.join(hook_dir, "run_*"))):
      # Record the different completed and active beam sequence ids.
      alive_sequences = deque()
      finished_sequences = deque()

      # Make the root vertex since it always needs to exist.
      decoding_graph.get_vertex(sequence_key([0]))

      # Create the vertices and edges for the active and finished sequences.
      # Each vertex is keyed by its full sequence path so that two instances
      # of the same output id at one step do not collide.
      dump_dir = tfdbg.DebugDumpDir(run_dir, validate=False)
      for seq_datum in dump_dir.find(predicate=seq_filter):
        sequences = np.array(seq_datum.get_tensor()).astype(int)[0]
        if "alive" in seq_datum.node_name:
          alive_sequences.append(sequences)
        if "finished" in seq_datum.node_name:
          finished_sequences.append(sequences)

        for sequence in sequences:
          index = sequence[-1]
          # Id 0 ends no real step (presumably padding), so it gets no edge.
          if index == 0:
            continue
          pieces = self.targets_vocab.decode_list(sequence)
          parent = decoding_graph.get_vertex(sequence_key(sequence[:-1]))
          current = decoding_graph.get_vertex(sequence_key(sequence))
          edge = decoding_graph.add_edge(parent, current)
          edge.data["label"] = pieces[-1]
          # Plain Python types keep the graph JSON-serializable.
          edge.data["label_id"] = int(index)
          edge.data["completed"] = bool(index == 1)

      # Examine the score results and store the scores with the associated
      # edges in the graph. Score dumps are paired, in order, with the
      # sequence dumps recorded above.
      for score_datum in dump_dir.find(predicate=scores_filter):
        if "alive" in score_datum.node_name:
          sequences = alive_sequences.popleft()
        if "finished" in score_datum.node_name:
          sequences = finished_sequences.popleft()

        scores = np.array(score_datum.get_tensor()).astype(float)[0]
        for i, score in enumerate(scores):
          sequence = sequences[i]
          if sequence[-1] == 0:
            continue
          vertex = decoding_graph.get_vertex(sequence_key(sequence))
          edge = decoding_graph.edges[vertex.in_edges[0]]
          edge.data["score"] = score
          edge.data["log_probability"] = score
          edge.data["total_log_probability"] = score
    return decoding_graph

  def _processing_visualization(self, query, result):
    """Builds the source/target text processing visualization dict."""
    # TODO(kstevens): Make this method public
    # pylint: disable=protected-access
    output_ids = decoding._save_until_eos(result["outputs"].flatten(), False)
    output_pieces = self.targets_vocab.decode_list(output_ids)
    output_token = [{"text": piece} for piece in output_pieces]
    output = self.targets_vocab.decode(output_ids)

    source_steps = [{
        "step_name": "Initial",
        "segment": [{
            "text": query
        }],
    }]

    target_steps = [{
        "step_name": "Initial",
        "segment": output_token,
    }, {
        "step_name": "Final",
        "segment": [{
            "text": output
        }],
    }]

    return {
        "visualization_name": "processing",
        "title": "Processing",
        "name": "processing",
        "query_processing": {
            "source_processing": source_steps,
            "target_processing": target_steps,
        },
    }
================================================
FILE: tensor2tensor/layers/__init__.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: tensor2tensor/layers/area_attention.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for area attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.layers import common_layers
import tensorflow.compat.v1 as tf
def lengths_to_area_mask(feature_length, length, max_area_size):
  """Builds a non-padding mask over areas from per-example lengths.

  Args:
    feature_length: a tensor of [batch_size] with each example's valid length.
    length: the (padded) length of the batch.
    max_area_size: the maximum area size considered.
  Returns:
    mask: a boolean tensor in shape of [batch_size, num_areas]; True where the
      area's summed padding (third output of compute_area_features) is zero.
  """
  in_range = tf.sequence_mask(feature_length, maxlen=length)
  is_padding = tf.cast(
      tf.expand_dims(tf.logical_not(in_range), 2), tf.float32)
  _, _, padding_sum, _, _ = compute_area_features(
      is_padding, max_area_width=max_area_size)
  has_padding = tf.cast(padding_sum, tf.bool)
  return tf.squeeze(tf.logical_not(has_padding), [2])
def _pool_one_shape(features_2d, area_width, area_height, batch_size,
                    width, height, depth, fn=tf.reduce_max, name=None):
  """Pools every area of one fixed shape over features_2d.

  Each (dy, dx) offset inside the area contributes one shifted window of the
  memory; stacking those windows on a new last axis and reducing with fn
  gives one pooled vector per area position.

  Args:
    features_2d: a Tensor in a shape of [batch_size, height, width, depth].
    area_width: the max width allowed for an area.
    area_height: the max height allowed for an area.
    batch_size: the batch size.
    width: the width of the memory.
    height: the height of the memory.
    depth: the depth of the features.
    fn: the TF function for the pooling.
    name: the op name.
  Returns:
    pool_tensor: A Tensor of shape [batch_size, num_areas, depth]
  """
  with tf.name_scope(name, default_name="pool_one_shape"):
    shifted_windows = []
    for dy in range(area_height):
      row_end = tf.maximum(height - area_height + 1 + dy, 0)
      for dx in range(area_width):
        col_end = tf.maximum(width - area_width + 1 + dx, 0)
        window = features_2d[:, dy:row_end, dx:col_end, :]
        shifted_windows.append(tf.reshape(window, [batch_size, -1, depth, 1]))
    stacked = tf.concat(shifted_windows, axis=3)
    return fn(stacked, axis=3)
def basic_pool(features, max_area_width, max_area_height=1, height=1,
               fn=tf.reduce_max, name=None):
  """Pools each area of up to max_area_height x max_area_width with fn.

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    fn: the TF function for the pooling.
    name: the namescope.
  Returns:
    pool_results: A Tensor of shape [batch_size, num_areas, depth]
    area_heights: A Tensor of shape [batch_size, num_areas, 1]
    area_widths: A Tensor of shape [batch_size, num_areas, 1]
  """
  with tf.name_scope(name, default_name="basic_pool"):
    shape = common_layers.shape_list(features)
    batch_size, length, depth = shape[0], shape[-2], shape[-1]
    width = length // height
    features_2d = tf.reshape(features, [batch_size, height, width, depth])
    # Int32 ones over the [batch, height, width] grid; slicing it yields one
    # entry per valid area position, which is then scaled by the area size.
    ones = tf.ones_like(features_2d[:, :, :, 0], dtype=tf.int32)
    pools, heights, widths = [], [], []
    for h_idx in range(max_area_height):
      area_h = h_idx + 1
      for w_idx in range(max_area_width):
        area_w = w_idx + 1
        pooled = _pool_one_shape(features_2d,
                                 area_width=area_w,
                                 area_height=area_h,
                                 batch_size=batch_size,
                                 width=width,
                                 height=height,
                                 depth=depth,
                                 fn=fn)
        pools.append(tf.reshape(pooled, [batch_size, -1, depth]))
        valid_positions = ones[:, h_idx:, w_idx:]
        heights.append(
            tf.reshape(valid_positions * area_h, [batch_size, -1]))
        widths.append(
            tf.reshape(valid_positions * area_w, [batch_size, -1]))
    pool_results = tf.concat(pools, axis=1)
    area_heights = tf.expand_dims(tf.concat(heights, axis=1), 2)
    area_widths = tf.expand_dims(tf.concat(widths, axis=1), 2)
    return pool_results, area_heights, area_widths
def _compute_sum_image(features, max_area_width, max_area_height=1, height=1,
                       name=None):
  """Sums `features` over every area using an integral (summed-area) image.

  The flattened features are viewed as a [height, width] grid. An integral
  image is built with two cumulative sums and zero-padded on the top and
  left. Each area sum is then obtained from four corner lookups:
  bottom_right + top_left - top_right - bottom_left.

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    name: the namescope.
  Returns:
    sum_image: A Tensor of shape [batch_size, num_areas, depth]
    area_heights: A Tensor of shape [batch_size, num_areas, 1]
    area_widths: A Tensor of shape [batch_size, num_areas, 1]
  """
  with tf.name_scope(name, default_name="compute_sum_image"):
    shape = common_layers.shape_list(features)
    batch_size, length, depth = shape[0], shape[-2], shape[-1]
    width = length // height
    grid = tf.reshape(features, [batch_size, height, width, depth])
    row_cumsum = tf.cumsum(grid, axis=-2, name="compute_integral_h")
    integral = tf.cumsum(row_cumsum, axis=-3, name="compute_integral_v")
    # The leading zero row/column lets areas touching the top or left edge
    # use the same four-corner formula as every other area.
    padded = tf.pad(
        integral, [[0, 0], [1, 0], [1, 0], [0, 0]], constant_values=0)
    ones = tf.ones_like(padded[:, :, :, 0], dtype=tf.int32)

    def flatten(image):
      return tf.reshape(image, [batch_size, -1, depth])

    bottom_right, top_left, bottom_left, top_right = [], [], [], []
    heights, widths = [], []
    for h_off in range(max_area_height):
      for w_off in range(max_area_width):
        span_h, span_w = h_off + 1, w_off + 1
        bottom_right.append(flatten(padded[:, span_h:, span_w:, :]))
        top_left.append(flatten(padded[:, :-span_h, :-span_w, :]))
        bottom_left.append(flatten(padded[:, span_h:, :-span_w, :]))
        top_right.append(flatten(padded[:, :-span_h, span_w:, :]))
        valid = ones[:, span_h:, span_w:]
        heights.append(tf.reshape(valid * span_h, [batch_size, -1]))
        widths.append(tf.reshape(valid * span_w, [batch_size, -1]))
    sum_image = tf.subtract(
        tf.concat(bottom_right, axis=1) + tf.concat(top_left, axis=1),
        tf.concat(top_right, axis=1) + tf.concat(bottom_left, axis=1))
    area_heights = tf.expand_dims(tf.concat(heights, axis=1), 2)
    area_widths = tf.expand_dims(tf.concat(widths, axis=1), 2)
  return sum_image, area_heights, area_widths
def compute_area_features(features, max_area_width, max_area_height=1, height=1,
                          epsilon=1e-6):
  """Computes the sum, mean and standard deviation of every area.

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    epsilon: the epsilon added to the variance for computing standard deviation.
  Returns:
    area_mean: A Tensor of shape [batch_size, num_areas, depth]
    area_std: A Tensor of shape [batch_size, num_areas, depth]
    area_sum: A Tensor of shape [batch_size, num_areas, depth]
    area_heights: A Tensor of shape [batch_size, num_areas, 1]
    area_widths: A Tensor of shape [batch_size, num_areas, 1]
  """
  with tf.name_scope("compute_area_features"):
    tf.logging.info("area_attention compute_area_features: %d x %d",
                    max_area_height, max_area_width)
    area_kwargs = dict(max_area_width=max_area_width,
                       max_area_height=max_area_height,
                       height=height)
    area_sum, area_heights, area_widths = _compute_sum_image(
        features, **area_kwargs)
    area_squared_sum = _compute_sum_image(tf.pow(features, 2), **area_kwargs)[0]
    float_area_sizes = tf.to_float(tf.multiply(area_heights, area_widths))
    area_mean = tf.div(area_sum, float_area_sizes)
    # Var[x] = E[x^2] - E[x]^2. The abs() presumably guards against tiny
    # negative values from floating-point cancellation before the sqrt.
    mean_of_squares = tf.div(area_squared_sum, float_area_sizes)
    area_variance = tf.subtract(mean_of_squares, tf.pow(area_mean, 2))
    area_std = tf.sqrt(tf.abs(area_variance) + epsilon)
  return area_mean, area_std, area_sum, area_heights, area_widths
def compute_area_key(features, max_area_width, max_area_height=1, height=1,
                     mode="mean", training=True, name=None):
  """Computes the key for each area.

  Args:
    features: a Tensor in a shape of [batch_size, height * width, depth].
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    height: the height of the image.
    mode: how the area keys are built. Supported values:
      "mean": the mean of each area.
      "max": the max-pooled features of each area.
      "sample": the area mean plus Gaussian noise scaled by the area std
        (noise is only added when `training`).
      "concat": a 2-layer MLP over [mean, std, size embedding].
      "max_concat": a 2-layer MLP over [max, size embedding].
      "sum": a 2-layer MLP over mean + std + size embedding.
      "sample_concat": a 2-layer MLP over [sampled mean, size embedding].
      "sample_sum": a 2-layer MLP over sampled mean + size embedding.
      The "sum" and "sample_sum" modes add the size embedding (2 * (depth // 2)
      wide) to depth-wide features, so they require an even depth.
    training: indicating if it is in the training mode.
    name: the name for setting the variable scope.
  Returns:
    area_key: a Tensor in the shape of [batch_size, num_areas, depth]
  Raises:
    ValueError: if `mode` is not one of the supported values.
  """
  supported_modes = ("mean", "max", "sample", "concat", "max_concat", "sum",
                     "sample_concat", "sample_sum")
  # Validate before building any ops so an invalid mode does not leave stray
  # embedding variables behind in the graph.
  if mode not in supported_modes:
    raise ValueError("Unsupported area key mode=%s" % mode)
  tf.logging.info("area_attention mode=%s", mode)
  area_mean, area_std, _, area_heights, area_widths = compute_area_features(
      features, max_area_width=max_area_width,
      max_area_height=max_area_height, height=height)
  if mode == "mean":
    return area_mean
  elif mode == "max":
    area_max, _, _ = basic_pool(features, max_area_width=max_area_width,
                                max_area_height=max_area_height, height=height)
    return area_max
  elif mode == "sample":
    if training:
      area_mean += (area_std * tf.random_normal(tf.shape(area_std)))
    return area_mean
  with tf.variable_scope(
      name, default_name="combine_area_features",
      values=[area_mean, area_std, area_heights, area_widths]):
    depth = common_layers.shape_list(area_mean)[-1]
    # Learned embeddings of the area height and width; the sizes are 1-based,
    # hence the "- 1" when used as lookup ids.
    height_embed = tf.nn.embedding_lookup(
        params=tf.get_variable("area_height_emb",
                               [max_area_height, depth // 2]),
        ids=area_heights[:, :, 0] - 1)
    width_embed = tf.nn.embedding_lookup(
        params=tf.get_variable("area_width_emb",
                               [max_area_width, depth // 2]),
        ids=area_widths[:, :, 0] - 1)
    size_embed = tf.concat([height_embed, width_embed], -1)
    if mode == "concat":
      feature_concat = tf.concat([area_mean, area_std, size_embed], -1)
    elif mode == "max_concat":
      area_max, _, _ = basic_pool(features, max_area_width=max_area_width,
                                  max_area_height=max_area_height,
                                  height=height)
      feature_concat = tf.concat([area_max, size_embed], -1)
    elif mode == "sum":
      feature_concat = size_embed + area_mean + area_std
    elif mode == "sample_concat":
      if training:
        area_mean += (area_std * tf.random_normal(tf.shape(area_std)))
      feature_concat = tf.concat([area_mean, size_embed], -1)
    elif mode == "sample_sum":
      if training:
        area_mean += (area_std * tf.random_normal(tf.shape(area_std)))
      feature_concat = area_mean + size_embed
    else:  # Defensive: unreachable because `mode` was validated above.
      raise ValueError("Unsupported area key mode=%s" % mode)
    feature_hidden = tf.layers.dense(inputs=feature_concat,
                                     units=depth,
                                     activation=tf.nn.relu)
    area_key = tf.layers.dense(feature_hidden, units=depth)
    return area_key
def dot_product_area_attention(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               attention_image_summary=None,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               max_area_width=1,
                               max_area_height=1,
                               memory_height=1,
                               area_key_mode="mean",
                               area_value_mode="sum",
                               top_k_areas=0,
                               area_temperature=1.0,
                               training=True):
  """Dot-product area attention.

  Keys and values are first aggregated over every area of up to
  max_area_height x max_area_width memory positions, with the memory viewed as
  a grid of memory_height rows. The queries then attend over those areas.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [batch, heads, length_kv, depth_k]. Leading dimensions
      must match with q.
    v: Tensor with shape [batch, heads, length_kv, depth_v]. Leading dimensions
      must match with q. depth_v may differ from depth_k.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    attention_image_summary: the callback for making image summary of attention.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "max", "sample", "concat", "max_concat", "sum", "sample_concat", or
      "sample_sum" (see compute_area_key).
    area_value_mode: the mode for computing area values, which can be "mean",
      "max", or "sum".
    top_k_areas: Use the top key areas for attention.
    area_temperature: the temperature for attention softmax.
    training: indicating if it is in the training mode.
  Returns:
    Tensor with shape [..., length_q, depth_v].
  Raises:
    ValueError: if area_value_mode (or area_key_mode) is unsupported.
  """
  tf.logging.info("dot_product_area_attention: "
                  "area_h=%d, area_w=%d, mem_h=%d, "
                  "area_key_mode=%s, area_value_mode=%s, "
                  "area_temperature=%f",
                  max_area_height, max_area_width, memory_height,
                  area_key_mode, area_value_mode,
                  area_temperature)
  with tf.variable_scope(
      name, default_name="dot_product_area_attention",
      values=[q, k, v]) as scope:
    mem_shape = common_layers.shape_list(k)
    batch_size = mem_shape[0]
    head_size = mem_shape[1]
    length = mem_shape[2]
    depth = mem_shape[3]
    # Values are reshaped with their own depth: reusing the key depth would
    # silently scramble the batch dimension whenever depth_v != depth_k.
    depth_v = common_layers.shape_list(v)[-1]
    k_area = compute_area_key(
        tf.reshape(k, [-1, length, depth]),
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        height=memory_height,
        mode=area_key_mode,
        training=training)
    if area_value_mode == "mean":
      v_area, _, _, _, _ = compute_area_features(
          tf.reshape(v, [-1, length, depth_v]), max_area_width=max_area_width,
          max_area_height=max_area_height, height=memory_height)
    elif area_value_mode == "max":
      v_area, _, _ = basic_pool(tf.reshape(v, [-1, length, depth_v]),
                                max_area_width=max_area_width,
                                max_area_height=max_area_height,
                                height=memory_height,
                                fn=tf.reduce_max)
    elif area_value_mode == "sum":
      _, _, v_area, _, _ = compute_area_features(
          tf.reshape(v, [-1, length, depth_v]), max_area_width=max_area_width,
          max_area_height=max_area_height, height=memory_height)
    else:
      raise ValueError("Unsupported area value mode=%s" % area_value_mode)
    k = tf.reshape(k_area, [batch_size, head_size, -1, depth])
    v = tf.reshape(v_area, [batch_size, head_size, -1, depth_v])
    logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, num_areas]
    if bias is not None:
      bias = common_layers.cast_like(bias, logits)
      with tf.name_scope("compute_area_att_bias", values=[bias]):
        bias_shape = common_layers.shape_list(bias)
        mem_length = bias_shape[-1]
        # Positions whose bias is below -1 are treated as padding; an area
        # containing any padded position gets a -inf logit.
        bias_values = tf.reshape(
            tf.to_float(tf.less(bias, -1)), [-1, mem_length, 1])
        _, _, padding_sum, _, _ = compute_area_features(
            bias_values, max_area_width=max_area_width,
            max_area_height=max_area_height, height=memory_height)
        bias = tf.where(
            tf.cast(tf.to_int32(padding_sum), tf.bool),
            tf.fill(tf.shape(padding_sum), -np.inf),
            tf.zeros_like(padding_sum, dtype=tf.float32))
        bias = tf.reshape(bias,
                          [bias_shape[0], bias_shape[1],
                           bias_shape[2], -1])
      logits += bias
    logits = logits / area_temperature
    weights = tf.nn.softmax(logits, name="attention_weights")
    if top_k_areas > 0:
      tf.logging.info("area_attention top_k_areas=%d", top_k_areas)
      # Keep only the top-k weights per query and renormalize them.
      top_k = tf.minimum(common_layers.shape_list(weights)[-1], top_k_areas)
      top_weights, _ = tf.nn.top_k(weights, k=top_k)
      min_values = tf.reduce_min(top_weights, -1, keepdims=True)
      weights = tf.where(tf.greater_equal(weights, min_values),
                         weights, tf.zeros_like(weights))
      weights = tf.div(weights, tf.reduce_sum(weights, -1, keepdims=True))
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = logits
    # Drop out attention links for each head.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and attention_image_summary:
      attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
================================================
FILE: tensor2tensor/layers/area_attention_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for area attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensor2tensor.layers import area_attention
import tensorflow.compat.v1 as tf
class AreaAttentionTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for area feature computation and dot-product area attention."""

  def testComputeAreaFeatures1D(self):
    features = tf.constant([[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]],
                            [[1.1, 2.1], [3.1, 4.1], [5.1, 6.1], [7.1, 8.1],
                             [9.1, 10.1]]],
                           dtype=tf.float32)
    area_mean, area_std, area_sum, area_height, area_widths = (
        area_attention.compute_area_features(features, max_area_width=3,
                                             epsilon=0.))
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      res1, res2, res3, res4, res5 = session.run([area_mean, area_std, area_sum,
                                                  area_height, area_widths])
    self.assertAllClose(((((1, 2), (3, 4), (5, 6), (7, 8), (9, 10),
                           (2, 3), (4, 5), (6, 7), (8, 9),
                           (3, 4), (5, 6), (7, 8)),
                          ((1.1, 2.1), (3.1, 4.1), (5.1, 6.1), (7.1, 8.1),
                           (9.1, 10.1),
                           (2.1, 3.1), (4.1, 5.1), (6.1, 7.1), (8.1, 9.1),
                           (3.1, 4.1), (5.1, 6.1), (7.1, 8.1)))),
                        res1,
                        msg="mean_1d")
    expected_std = np.array([[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
                              [1, 1], [1, 1], [1, 1], [1, 1],
                              [1.63299, 1.63299], [1.63299, 1.63299],
                              [1.63299, 1.63299]],
                             [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
                              [1, 1], [1, 1], [1, 1], [1, 1],
                              [1.63299, 1.63299], [1.63299, 1.63299],
                              [1.63299, 1.63299]]])
    self.assertAllClose(expected_std, res2, atol=1e-2, msg="std_1d")
    self.assertAllClose([[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
                          [4, 6], [8, 10], [12, 14], [16, 18],
                          [9, 12], [15, 18], [21, 24]],
                         [[1.1, 2.1], [3.1, 4.1], [5.1, 6.1], [7.1, 8.1],
                          [9.1, 10.1],
                          [4.2, 6.2], [8.2, 10.2], [12.2, 14.2], [16.2, 18.2],
                          [9.3, 12.3], [15.3, 18.3], [21.3, 24.3]]],
                        res3,
                        msg="sum_1d")
    self.assertAllEqual([[[1], [1], [1], [1], [1],
                          [1], [1], [1], [1],
                          [1], [1], [1]],
                         [[1], [1], [1], [1], [1],
                          [1], [1], [1], [1],
                          [1], [1], [1]]],
                        res4,
                        msg="height_1d")
    self.assertAllEqual([[[1], [1], [1], [1], [1],
                          [2], [2], [2], [2],
                          [3], [3], [3]],
                         [[1], [1], [1], [1], [1],
                          [2], [2], [2], [2],
                          [3], [3], [3]]],
                        res5,
                        msg="width_1d")

  def testComputeAreaFeatures2D(self):
    features = tf.constant([[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                            [[1.1, 2.1], [3.1, 4.1], [5.1, 6.1], [7.1, 8.1],
                             [9.1, 10.1], [11.1, 12.1]]],
                           dtype=tf.float32)
    area_mean, area_std, area_sum, area_height, area_widths = (
        area_attention.compute_area_features(features, max_area_width=3,
                                             max_area_height=2,
                                             height=2, epsilon=0.))
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      res1, _, res3, res4, res5 = session.run([area_mean, area_std, area_sum,
                                               area_height, area_widths])
    expected_means = [[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
                       [2, 3], [4, 5], [8, 9], [10, 11],
                       [3, 4], [9, 10],
                       [4, 5], [6, 7], [8, 9],
                       [5, 6], [7, 8],
                       [6, 7]],
                      [[1.1, 2.1], [3.1, 4.1], [5.1, 6.1], [7.1, 8.1],
                       [9.1, 10.1], [11.1, 12.1],
                       [2.1, 3.1], [4.1, 5.1], [8.1, 9.1], [10.1, 11.1],
                       [3.1, 4.1], [9.1, 10.1],
                       [4.1, 5.1], [6.1, 7.1], [8.1, 9.1],
                       [5.1, 6.1], [7.1, 8.1],
                       [6.1, 7.1]]]
    self.assertAllClose(expected_means, res1, msg="mean_2d")
    expected_heights = [[[1], [1], [1], [1], [1], [1],
                         # 1x2
                         [1], [1], [1], [1],
                         # 1x3
                         [1], [1],
                         # 2x1
                         [2], [2], [2],
                         # 2x2
                         [2], [2],
                         # 2x3
                         [2]],
                        [[1], [1], [1], [1], [1], [1],
                         # 1x2
                         [1], [1], [1], [1],
                         # 1x3
                         [1], [1],
                         # 2x1
                         [2], [2], [2],
                         # 2x2
                         [2], [2],
                         # 2x3
                         [2]]]
    self.assertAllEqual(expected_heights, res4, msg="height_2d")
    expected_widths = [[[1], [1], [1], [1], [1], [1],
                        # 1x2
                        [2], [2], [2], [2],
                        # 1x3
                        [3], [3],
                        # 2x1
                        [1], [1], [1],
                        # 2x2
                        [2], [2],
                        # 2x3
                        [3]],
                       [[1], [1], [1], [1], [1], [1],
                        # 1x2
                        [2], [2], [2], [2],
                        # 1x3
                        [3], [3],
                        # 2x1
                        [1], [1], [1],
                        # 2x2
                        [2], [2],
                        # 2x3
                        [3]]]
    self.assertAllEqual(expected_widths, res5, msg="width_2d")
    # Area sums must equal mean * (height * width) for every area.
    sizes = np.multiply(np.array(expected_heights), np.array(expected_widths))
    expected_sums = np.multiply(np.array(expected_means), sizes)
    self.assertAllClose(expected_sums, res3, msg="sum_2d")

  def testAreaMean(self):
    batch_size = 256
    feature_len = 100
    memory_height = 10
    heads = 2
    key_len = 2
    depth = 128
    max_area_height = 3
    max_area_width = 3
    queries = tf.random_uniform([batch_size, heads, key_len, depth],
                                minval=-10.0, maxval=10.0)
    features = tf.random_uniform([batch_size, heads, feature_len, depth],
                                 minval=-10.0, maxval=10.0)
    target_values = tf.random_uniform([batch_size, heads, key_len, depth],
                                      minval=-0.2, maxval=0.2)
    keys = tf.layers.dense(features, units=depth)
    values = tf.layers.dense(features, units=depth)
    mean_attention = area_attention.dot_product_area_attention(
        queries, keys, values,
        bias=None,
        area_key_mode="mean",
        name="mean_key",
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        memory_height=memory_height)
    mean_gradients = tf.gradients(
        tf.reduce_mean(
            tf.pow(target_values - mean_attention, 2)), features)
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      result = session.run([mean_gradients])
    self.assertTrue(np.all(np.isfinite(result)))

  def test2DAreaMax(self):
    batch_size = 256
    feature_len = 100
    memory_height = 10
    heads = 2
    key_len = 6
    depth = 128
    max_area_height = 3
    max_area_width = 3
    queries = tf.random_uniform([batch_size, heads, key_len, depth],
                                minval=-10.0, maxval=10.0)
    features = tf.random_uniform([batch_size, heads, feature_len, depth],
                                 minval=-10.0, maxval=10.0)
    target_values = tf.random_uniform([batch_size, heads, key_len, depth],
                                      minval=-0.2, maxval=0.2)
    keys = tf.layers.dense(features, units=depth)
    values = tf.layers.dense(features, units=depth)
    max_attention = area_attention.dot_product_area_attention(
        queries, keys, values,
        bias=None,
        area_key_mode="max",
        area_value_mode="max",
        name="max_key",
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        memory_height=memory_height)
    max_gradients = tf.gradients(tf.reduce_mean(
        tf.pow(target_values - max_attention, 2)), features)
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      result1, result2 = session.run([max_gradients, max_attention])
    self.assertTrue(np.all(np.isfinite(result1)))
    self.assertTrue(np.all(np.isfinite(result2)))

  def test1DAreaMax(self):
    batch_size = 256
    feature_len = 100
    heads = 2
    key_len = 15
    depth = 128
    max_area_width = 3
    queries = tf.random_uniform([batch_size, heads, key_len, depth],
                                minval=-10.0, maxval=10.0)
    features = tf.random_uniform([batch_size, heads, feature_len, depth],
                                 minval=-10.0, maxval=10.0)
    # Random valid lengths (at least max_area_width); the last example is
    # always full length.
    feature_length = tf.constant(
        np.concatenate(
            (np.random.randint(max_area_width, feature_len, [batch_size - 1]),
             np.array([feature_len])), axis=0), tf.int32)
    base_mask = tf.expand_dims(tf.sequence_mask(feature_length), 1)
    mask = tf.expand_dims(base_mask, 3)
    mask = tf.tile(mask, [1, heads, 1, depth])
    features = tf.where(mask, features, tf.zeros_like(features))
    # [batch, 1, 1, memory_length]
    bias_mask = tf.expand_dims(base_mask, 1)
    bias = tf.where(
        bias_mask,
        tf.zeros_like(bias_mask, tf.float32),
        tf.ones_like(bias_mask, tf.float32) * -1e9)
    target_values = tf.random_uniform([batch_size, heads, key_len, depth],
                                      minval=-0.2, maxval=0.2)
    keys = tf.layers.dense(features, units=depth)
    values = tf.layers.dense(features, units=depth)
    max_attention = area_attention.dot_product_area_attention(
        queries, keys, values,
        bias=bias,
        area_key_mode="max",
        area_value_mode="max",
        name="max_key",
        max_area_width=max_area_width)
    max_gradients = tf.gradients(
        tf.reduce_mean(
            tf.pow(target_values - max_attention, 2)), features)
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      result1, result2 = session.run([max_gradients, max_attention])
    self.assertTrue(np.all(np.isfinite(result1)))
    self.assertTrue(np.all(np.isfinite(result2)))
if __name__ == "__main__":
  # Entry point: run the test cases in this module with TensorFlow's runner.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/common_attention.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import itertools
import math
import operator
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
from tensor2tensor.layers import area_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import contrib
from tensor2tensor.utils import expert_utils
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import function
from tensorflow.python.ops import inplace_ops
# pylint: enable=g-direct-tensorflow-import
# TODO(lukaszkaiser): remove this function when not needed any more.
def layers():
  """Thin compatibility alias that returns `common_layers.layers()`."""
  layers_module = common_layers.layers()
  return layers_module
def large_compatible_negative(tensor_type):
  """Returns a large negative number that `tensor_type` can represent.

  The large negative constant used for masking in this module (-1e9) cannot
  be represented in tf.float16, so the most negative finite float16 value is
  used for that dtype instead.

  Args:
    tensor_type: a dtype to determine the type.

  Returns:
    `tf.float16.min` when `tensor_type` is tf.float16, otherwise -1e9.
  """
  return tf.float16.min if tensor_type == tf.float16 else -1e9
def mixed_precision_is_enabled(
    activation_dtype=None, weight_dtype=None, hparams=None):
  """Tells whether activations are float16 while weights are float32.

  The dtypes come either from the explicit arguments or from `hparams`
  (when it has both `activation_dtype` and `weight_dtype` attributes).
  Do not pass both forms at once.

  Args:
    activation_dtype: optional dtype of the activations.
    weight_dtype: optional dtype of the weights.
    hparams: optional hparams object carrying both dtypes.

  Returns:
    True iff activation_dtype is tf.float16 and weight_dtype is tf.float32.
  """
  assert not (hparams and (activation_dtype or weight_dtype)), (
      "Provide only hparams or activation_dtype and weight_dtype")
  hparams_has_dtypes = (hparams and hasattr(hparams, "activation_dtype") and
                        hasattr(hparams, "weight_dtype"))
  if hparams_has_dtypes:
    activation_dtype, weight_dtype = (hparams.activation_dtype,
                                      hparams.weight_dtype)
  return activation_dtype == tf.float16 and weight_dtype == tf.float32
def maybe_upcast(logits,
                 activation_dtype=None, weight_dtype=None, hparams=None):
  """Casts `logits` to float32 when mixed precision is enabled.

  Args:
    logits: a Tensor.
    activation_dtype: optional activation dtype (see mixed_precision_is_enabled).
    weight_dtype: optional weight dtype (see mixed_precision_is_enabled).
    hparams: optional hparams carrying both dtypes.

  Returns:
    `logits` cast to float32 under mixed precision, otherwise `logits` as-is.
  """
  if not mixed_precision_is_enabled(activation_dtype, weight_dtype, hparams):
    return logits
  return tf.cast(logits, tf.float32)
# Struct containing the sequence ids and order of a batch (these are sent to
# the experts so that they can compute the bias mask).
BatchInfo = collections.namedtuple("BatchInfo", "coordinates, order")
# NOTE(review): module-level counter that is not referenced in the code
# visible here; confirm its users before changing or removing it.
_expert_count = 0
def get_standardized_layers(hparams, dp=None):
  """Get the common attention and feed-forward layers.

  The returned layer functions will have the following signature:

    y, extra_loss = fct(x)

  extra_loss is set to 0.0 if the layer doesn't have extra loss.
  If dp is provided, the layers will be distributed within the devices.
  Layers registered with `use_dp=False` (e.g. MoE) are always called
  without dp.

  Args:
    hparams (tf.HParams): the model hparameters
    dp (expert_utils.Parallelism): A data parallelism object. If not given,
      the dp calls are simply ignored.

  Returns:
    dict[str:fct]: A dictionary containing the standardized functions, keyed
      by "a", "loc", "locm", "red", "redm", "mem" (attention) and "fc", "sep",
      "sepm" (feed-forward).
  """

  def partial(fct, *args, **kwargs):
    """Same as functools.partial but with functools.wraps."""
    return functools.wraps(fct)(functools.partial(fct, *args, **kwargs))

  def register_layer(
      fct_in,
      default_args=None,
      default_kwargs=None,
      use_dp=True,
      recompute_grad=False,
  ):
    """Turn a function into its standardized version.

    Args:
      fct_in (fct): The function to register
      default_args (list): The default parameters to add to the function.
      default_kwargs (dict): The default parameters to add to the function.
        Those arguments can be overwritten when calling the function.
      use_dp (bool): Wrap the function call within a dataparallelism object if
        dp is available. Some layers (like MOE) must be called without dp.
      recompute_grad (bool): If True, recompute the function during the
        backward pass to save memory

    Returns:
      fct: the standardized layer function.
    """
    # The kwargs given when calling the function overwrite the default ones
    fct_in = partial(fct_in, *(default_args or []), **(default_kwargs or {}))

    @functools.wraps(fct_in)
    def decorator(x, *args, **kwargs):
      """Call the layer function."""
      fct = fct_in  # For closure. Could use nonlocal with Python 3
      # Eventually create the memory optimized version of the function
      if recompute_grad:
        fct = partial(fct, **kwargs)  # recompute_grad only accept args
        fct = common_layers.recompute_grad(fct)
        kwargs = {}

      # Eventually use dp (if given and not MoE)
      if use_dp and dp is not None:
        y = dp(fct, x, *args, **kwargs)
      else:
        y = fct(x, *args, **kwargs)

      # Eventually capture the extra loss
      extra_loss = 0.0
      if isinstance(y, tuple):
        y, extra_loss = y

      return y, extra_loss

    return decorator

  total_key_depth = hparams.attention_key_channels or hparams.hidden_size
  total_value_depth = hparams.attention_value_channels or hparams.hidden_size

  # Attention layers:

  # === Multi-head full attention layer ===
  multihead_attention_fn = register_layer(
      multihead_attention,
      default_kwargs=dict(
          memory_antecedent=None,  # Self-attention by default
          bias=None,
          total_key_depth=total_key_depth,
          total_value_depth=total_value_depth,
          output_depth=hparams.hidden_size,
          num_heads=hparams.num_heads,
          dropout_rate=hparams.attention_dropout,
      ))

  # === Memory efficient full-attention layer ===
  # Save memory by not storing the activations and
  # recomputing them during the backward pass
  memeff_attention_base_fn = register_layer(
      multihead_attention,
      default_kwargs=dict(
          total_key_depth=total_key_depth,
          total_value_depth=total_value_depth,
          output_depth=hparams.hidden_size,
          num_heads=hparams.num_heads,
          dropout_rate=hparams.attention_dropout,
      ),
      recompute_grad=True,
  )

  def memeff_attention_fn(*args, **kwargs):
    """Modify args/kwargs for compatibility with recompute_grad."""
    kwargs = kwargs.copy()
    assert len(args) == 1
    x = args[0]
    # Same as x if missing OR explicitly None: the memory is passed
    # positionally to recompute_grad, which does not support None values.
    memory_antecedent = kwargs.pop("memory_antecedent", None)
    if memory_antecedent is None:
      memory_antecedent = x
    if kwargs.get("bias", None) is not None:  # Case where bias has been set
      args = (x, memory_antecedent, kwargs.pop("bias"))
    else:
      # Otherwise, only 2 args. This is necessary as recompute_grad does not
      # support None values.
      args = (x, memory_antecedent)
    return memeff_attention_base_fn(*args, **kwargs)

  # === Local attention (unmasked) layer ===
  # Reuse same parameters as multihead_attention
  # Don't mask the future
  local_attention_fn = partial(
      multihead_attention_fn,
      block_length=hparams.attention_loc_block_length,
      block_width=hparams.attention_loc_block_width,
      attention_type="local_unmasked",
  )

  # === Local attention (masked) layer ===
  # Reuse same parameters as multihead_attention
  # Only works for self attention. Always mask the future.
  local_attention_masked_fn = partial(
      multihead_attention_fn,
      block_length=hparams.attention_loc_block_length,
      attention_type="local_mask_right",
  )

  # === Masked memory-compressed multihead self attention layer ===
  # Only works for self attention. Always mask the future.
  compressed_attention_masked_fn = register_layer(
      multihead_self_attention_reduced,
      default_kwargs=dict(
          factor=hparams.attention_red_factor,
          nonlinearity=hparams.attention_red_nonlinearity,
          reduction_type=hparams.attention_red_type,
          multihead_params=dict(
              total_key_depth=total_key_depth,
              total_value_depth=total_value_depth,
              num_heads=hparams.num_heads,
              dropout_rate=hparams.attention_dropout,
          ),
      ),
  )

  # === Unmasked memory-compressed multihead self attention layer ===
  # Only works for self attention. Never mask the future. Bias never added
  compressed_attention_fn = partial(
      compressed_attention_masked_fn,
      add_mask=False,
  )

  # Feed-forwards layers:

  # === FC layer ===
  conv_hidden_relu = register_layer(
      common_layers.conv_hidden_relu,
      default_kwargs=dict(
          hidden_size=hparams.filter_size,
          output_size=hparams.hidden_size,
          dropout=hparams.relu_dropout,
      ),
  )

  # === Separable convolution layer ===
  # No mask applied
  sep_conv_relu = partial(
      conv_hidden_relu,
      padding="SAME",
      # Parameters copied from the transformer model, could add hparams
      kernel_size=(3, 1),
      second_kernel_size=(31, 1),
  )

  # === Separable convolution layer (masked version) ===
  # Mask the future
  sep_conv_relu_masked = partial(
      sep_conv_relu,
      padding="LEFT",  # Mask future for decoder
  )

  # Define all available layers

  cur_layers = dict(
      # Attention layers:
      a=multihead_attention_fn,  # Multihead full attention
      loc=local_attention_fn,  # Local attention
      locm=local_attention_masked_fn,  # Local attention (masked)
      red=compressed_attention_fn,  # Memory-compressed attention
      redm=compressed_attention_masked_fn,  # Memory-compressed att (masked)
      mem=memeff_attention_fn,  # Memory efficient
      # Feed-forward layers:
      fc=conv_hidden_relu,  # Fully connected
      sep=sep_conv_relu,  # Separable convolution (unmasked)
      sepm=sep_conv_relu_masked,  # Separable convolution (masked)
  )
  return cur_layers
def add_standard_attention_hparams(hparams):
  """Registers the hparams used by get_standardized_layers.

  All hyperparameters ending in "dropout" are automatically set to 0.0 when
  not in training mode.

  These hparams are also used but must be defined elsewhere (in
  common_hparams):
    * global flags: mode, hidden_size
    * pre/post-processing flags: layer_preprocess_sequence,
      layer_postprocess_sequence, layer_prepostprocess_dropout, norm_type,
      norm_epsilon
    * mixture-of-experts flags: moe_hidden_sizes, moe_num_experts, moe_k,
      moe_loss_coef

  Args:
    hparams: an hparams object exposing `add_hparam(name, value)`.

  Returns:
    The same `hparams` object, with the entries below added in order.
  """
  new_entries = (
      # Attention layers flags.
      ("num_heads", 8),
      ("attention_key_channels", 0),
      ("attention_value_channels", 0),
      ("attention_dropout", 0.0),
      # Attention: Local.
      ("attention_loc_block_length", 256),
      # Attention: Local (unmasked only): how much to look left.
      ("attention_loc_block_width", 128),
      # Attention: Memory-compressed.
      ("attention_red_factor", 3),
      ("attention_red_type", "conv"),
      ("attention_red_nonlinearity", "none"),
      # Fully connected layers flags. To be more consistent, filter_size
      # could also control the MoE size if moe_hidden_sizes is not set.
      ("filter_size", 2048),
      ("relu_dropout", 0.0),
  )
  for hparam_name, default_value in new_entries:
    hparams.add_hparam(hparam_name, default_value)
  return hparams
def encoder_decoder_attention_loss(expected_attention_logits,
                                   actual_attentions,
                                   loss_type="kl_divergence",
                                   loss_multiplier=1.0):
  """Computes encdec attention loss between expected and actual attentions.

  Only entries of `actual_attentions` whose key contains "encdec_attention"
  are used. They are averaged over layers and heads before being compared.

  Args:
    expected_attention_logits: Tensor storing the expected encoder-decoder
      attention logits with shape [batch_size, target_length, input_length].
    actual_attentions: Dictionary with actual attention logits for different
      attention types and hidden layers. Keys ending in "/logits" hold logits;
      the other keys hold attention weights.
    loss_type: "mse" compares softmax(expected logits) with the actual
      weights by mean squared error; any other value uses the KL divergence
      between the expected and actual logits.
    loss_multiplier: multiplier for the attention loss.

  Returns:
    The selected loss, multiplied by `loss_multiplier`.
  """

  def _mean_over_layers_and_heads(tensors):
    # Each tensor is [batch_size, num_heads, target_length, input_length];
    # stacking adds a leading layer axis, then layers (0) and heads (2) are
    # averaged out, leaving [batch_size, target_length, input_length].
    return tf.reduce_mean(tf.stack(tensors), [0, 2])

  def _encdec_entries(want_logits):
    return [
        tensor for key, tensor in actual_attentions.items()
        if "encdec_attention" in key and key.endswith("/logits") == want_logits
    ]

  if loss_type == "mse":
    actual_weights = _mean_over_layers_and_heads(_encdec_entries(False))
    expected_weights = tf.nn.softmax(expected_attention_logits)
    loss = tf.losses.mean_squared_error(expected_weights, actual_weights)
  else:
    actual_logits = _mean_over_layers_and_heads(_encdec_entries(True))
    loss = tfp.distributions.kl_divergence(
        tfp.distributions.Categorical(logits=expected_attention_logits),
        tfp.distributions.Categorical(logits=actual_logits))
  return loss * loss_multiplier
@expert_utils.add_name_scope()
def get_timing_signal_1d(length,
                         channels,
                         min_timescale=1.0,
                         max_timescale=1.0e4,
                         start_index=0):
  """Builds a sinusoidal timing signal of shape [1, length, channels].

  Each channel carries a sinusoid of a different frequency and phase, which
  lets attention pick up absolute and relative positions. Relative positions
  are recoverable because sin(x+y) and cos(x+y) can be written in terms of y,
  sin(x) and cos(x). The signal should be added to precursors of both the
  query and the memory inputs to attention.

  The timescales form a geometric sequence from min_timescale to
  max_timescale, with channels // 2 distinct values. For every timescale the
  signals sin(timestep/timescale) and cos(timestep/timescale) are produced;
  all sines are followed by all cosines along the channel axis. If `channels`
  is odd, one zero channel is appended. Note that this slightly differs from
  the published paper; see
  https://github.com/tensorflow/tensor2tensor/pull/177

  Args:
    length: scalar, number of positions in the signal.
    channels: scalar, depth of the timing embedding.
    min_timescale: a float, smallest timescale.
    max_timescale: a float, largest timescale.
    start_index: index of the first position.

  Returns:
    a Tensor of timing signals [1, length, channels]
  """
  positions = tf.to_float(tf.range(length) + start_index)
  half_channels = channels // 2
  # Clamp the denominator to 1 so a single timescale does not divide by zero.
  denominator = tf.maximum(tf.to_float(half_channels) - 1, 1)
  log_increment = (
      math.log(float(max_timescale) / float(min_timescale)) / denominator)
  timescale_ids = tf.to_float(tf.range(half_channels))
  inv_timescales = min_timescale * tf.exp(timescale_ids * -log_increment)
  # [length, 1] * [1, half_channels] -> [length, half_channels]
  angles = tf.expand_dims(positions, 1) * tf.expand_dims(inv_timescales, 0)
  sinusoids = tf.concat([tf.sin(angles), tf.cos(angles)], axis=1)
  odd_padding = tf.mod(channels, 2)
  sinusoids = tf.pad(sinusoids, [[0, 0], [0, odd_padding]])
  return tf.reshape(sinusoids, [1, length, channels])
@expert_utils.add_name_scope()
def add_timing_signal_1d(x,
                         min_timescale=1.0,
                         max_timescale=1.0e4,
                         start_index=0):
  """Adds a sinusoidal timing signal to a [batch, length, channels] Tensor.

  The signal comes from get_timing_signal_1d: a geometric sequence of
  channels // 2 timescales between min_timescale and max_timescale, each
  contributing sin(timestep/timescale) and cos(timestep/timescale) along the
  channel axis. Because sin(x+y) and cos(x+y) can be expressed via y, sin(x)
  and cos(x), attention can use both absolute and relative positions.

  Args:
    x: a Tensor with shape [batch, length, channels]
    min_timescale: a float
    max_timescale: a float
    start_index: index of first position

  Returns:
    a Tensor the same shape as x.
  """
  dims = common_layers.shape_list(x)
  length, channels = dims[1], dims[2]
  timing = get_timing_signal_1d(
      length, channels, min_timescale, max_timescale, start_index)
  # The signal is float32; match x's dtype before adding.
  return x + common_layers.cast_like(timing, x)
@expert_utils.add_name_scope()
def get_layer_timing_signal_learned_1d(channels, layer, num_layers):
  """Returns the learned embedding for one layer (vertical timing signal).

  A variable "layer_embedding" of shape [num_layers, 1, 1, channels] holds one
  embedding per layer. It is initialized with stddev channels**-0.5 and
  rescaled by channels**0.5 on read.

  Args:
    channels: dimension of the timing signal
    layer: layer num
    num_layers: total number of layers

  Returns:
    a Tensor of timing signals [1, 1, channels].
  """
  embedding_shape = [num_layers, 1, 1, channels]
  init = tf.random_normal_initializer(0, channels**-0.5)
  table = tf.get_variable("layer_embedding", embedding_shape, initializer=init)
  scaled_table = table * (channels**0.5)
  return scaled_table[layer, :, :, :]
@expert_utils.add_name_scope()
def add_layer_timing_signal_learned_1d(x, layer, num_layers):
  """Add n-dimensional embedding as the layer (vertical) timing signal.

  Adds embeddings to represent the position of the layer in the tower.

  Args:
    x: a tensor with shape [batch, length, depth]
    layer: layer num
    num_layers: total number of layers

  Returns:
    a Tensor the same shape (and dtype) as x.
  """
  channels = common_layers.shape_list(x)[-1]
  signal = get_layer_timing_signal_learned_1d(channels, layer, num_layers)
  # The embedding variable is float32; cast it so adding it to reduced-precision
  # activations (e.g. bfloat16) works, as add_timing_signal_1d does.
  return x + common_layers.cast_like(signal, x)
@expert_utils.add_name_scope()
def get_layer_timing_signal_sinusoid_1d(channels, layer, num_layers):
  """Returns the sinusoidal layer (vertical) timing signal for one layer.

  A 1-D sinusoidal signal is built over `num_layers` positions with
  get_timing_signal_1d, and the entry for `layer` is selected.

  Args:
    channels: dimension of the timing signal
    layer: layer num
    num_layers: total number of layers

  Returns:
    a Tensor of timing signals [1, 1, channels].
  """
  per_layer = get_timing_signal_1d(num_layers, channels)
  # per_layer is [1, num_layers, channels]; keep a length-1 axis after slicing.
  selected = per_layer[:, layer, :]
  return tf.expand_dims(selected, axis=1)
@expert_utils.add_name_scope()
def add_layer_timing_signal_sinusoid_1d(x, layer, num_layers):
  """Add sinusoids of different frequencies as layer (vertical) timing signal.

  Args:
    x: a Tensor with shape [batch, length, channels]
    layer: layer num
    num_layers: total number of layers

  Returns:
    a Tensor the same shape (and dtype) as x.
  """
  channels = common_layers.shape_list(x)[-1]
  signal = get_layer_timing_signal_sinusoid_1d(channels, layer, num_layers)
  # The sinusoids are float32; cast them to x's dtype (as add_timing_signal_1d
  # does) so reduced-precision activations can be used.
  return x + common_layers.cast_like(signal, x)
@expert_utils.add_name_scope()
def add_timing_signals_given_positions(x,
                                       positions,
                                       min_timescale=1.0,
                                       max_timescale=1.0e4):
  """Adds sinusoids of diff frequencies to a Tensor, with timing positions given.

  Each entry of `positions` gets its own slice of 2 * num_timescales channels
  (sines then cosines), where num_timescales = channels // (len(positions)*2).
  Any channels left over after the last slice receive no signal.

  Args:
    x: a Tensor with shape [batch, length, channels]
    positions: a list of positions, each of which can either be a Tensor of
      shape [batch, length] or None for a default of [0..length)
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor the same shape as x.
  """
  shape = common_layers.shape_list(x)
  batch = shape[0]
  length = shape[1]
  channels = shape[2]
  num_dims = len(positions)
  num_timescales = channels // (num_dims * 2)
  # Clamp the denominator to 1: with a single timescale the unclamped version
  # divides by zero and the signal becomes NaN (matches get_timing_signal_1d).
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      tf.maximum(tf.to_float(num_timescales) - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  for dim, position in enumerate(positions):
    if position is None:
      # Create a [batch, length] Tensor of incrementing positions 0..length-1.
      position = tf.tile(
          tf.transpose(tf.expand_dims(tf.range(0, length), axis=1)), [batch, 1])
    # [batch, length, 1] * [1, 1, num_timescales]
    scaled_time = (
        tf.expand_dims(tf.to_float(position), 2) *
        tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
    # Place this dimension's sinusoids in its own channel slice.
    prepad = dim * 2 * num_timescales
    postpad = channels - (dim + 1) * 2 * num_timescales
    signal = tf.pad(signal, [[0, 0], [0, 0], [prepad, postpad]])
    signal = common_layers.cast_like(signal, x)
    x += signal
  return x
@expert_utils.add_name_scope()
def add_timing_signals_from_features(x,
                                     features,
                                     position_features,
                                     min_timescale=1.0,
                                     max_timescale=1.0e4):
  """Adds timing signals whose positions are looked up in `features`.

  Each comma-separated item of `position_features` is looked up with
  `features.get`. A key that is absent (for example the empty string) yields
  None, which add_timing_signals_given_positions treats as the default
  position tensor [0..length).

  Args:
    x: a Tensor with shape [batch, length, channels]
    features: a features dictionary
    position_features: a comma-delimited string of feature keys (items may be
      empty).
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor the same shape as x.
  """
  keys = position_features.split(",")
  position_tensors = [features.get(key) for key in keys]
  return add_timing_signals_given_positions(
      x, position_tensors, min_timescale, max_timescale)
@expert_utils.add_name_scope()
def add_timing_signal_1d_given_position(x,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
  """Adds sinusoids of diff frequencies to a Tensor, with timing position given.

  Args:
    x: a Tensor with shape [batch, length, channels]
    position: a Tensor with shape [batch, length]
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor the same shape as x.
  """
  channels = common_layers.shape_list(x)[2]
  num_timescales = channels // 2
  # Clamp the denominator to 1: with channels in {2, 3} there is one timescale
  # and the unclamped division by zero would make the signal NaN.
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      tf.maximum(tf.to_float(num_timescales) - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  # [batch, length, 1] * [1, 1, num_timescales]
  scaled_time = (
      tf.expand_dims(tf.to_float(position), 2) * tf.expand_dims(
          tf.expand_dims(inv_timescales, 0), 0))
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
  # Append a zero channel when `channels` is odd.
  signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
  signal = common_layers.cast_like(signal, x)
  return x + signal
@expert_utils.add_name_scope()
def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4):
  """Adds a bunch of sinusoids of different frequencies to a Tensor.

  Each channel of the input Tensor is incremented by a sinusoid of a different
  frequency and phase in one of the positional dimensions.

  This allows attention to learn to use absolute and relative positions.
  Timing signals should be added to some precursors of both the query and the
  memory inputs to attention.

  The use of relative position is possible because sin(a+b) and cos(a+b) can be
  expressed in terms of b, sin(a) and cos(a).

  x is a Tensor with n "positional" dimensions, e.g. one dimension for a
  sequence or two dimensions for an image.

  We use a geometric sequence of timescales starting with
  min_timescale and ending with max_timescale. The number of different
  timescales is equal to channels // (n * 2). For each timescale, we
  generate the two sinusoidal signals sin(timestep/timescale) and
  cos(timestep/timescale). Positional dimension d occupies its own slice of
  2 * num_timescales channels; leftover channels receive no signal.

  Args:
    x: a Tensor with shape [batch, d1 ... dn, channels]
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor the same shape as x.
  """
  num_dims = len(x.get_shape().as_list()) - 2
  channels = common_layers.shape_list(x)[-1]
  num_timescales = channels // (num_dims * 2)
  # Clamp the denominator to 1 so a single timescale does not produce NaNs
  # (matches get_timing_signal_1d).
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      tf.maximum(tf.to_float(num_timescales) - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  for dim in range(num_dims):
    length = common_layers.shape_list(x)[dim + 1]
    position = tf.to_float(tf.range(length))
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    prepad = dim * 2 * num_timescales
    postpad = channels - (dim + 1) * 2 * num_timescales
    signal = tf.pad(signal, [[0, 0], [prepad, postpad]])
    # Reshape [length, channels] so it broadcasts along axis dim + 1 only.
    for _ in range(1 + dim):
      signal = tf.expand_dims(signal, 0)
    for _ in range(num_dims - 1 - dim):
      signal = tf.expand_dims(signal, -2)
    # The signal is float32; match x's dtype like the other timing helpers.
    x += common_layers.cast_like(signal, x)
  return x
def add_positional_embedding(x, max_length, name=None, positions=None):
  """Adds a learned positional embedding to x.

  The embedding table is a variable of shape [max_length, depth], cast to
  x's dtype. With explicit `positions`, rows are gathered per position.
  Otherwise the first `length` rows are used; if length >= max_length the
  table is zero-padded to `length` rows instead.

  Args:
    x: Tensor with shape [batch, length, depth].
    max_length: int representing static maximum size of any dimension.
    name: str representing name of the embedding tf.Variable.
    positions: Tensor with shape [batch, length].

  Returns:
    Tensor of same shape as x.
  """
  with tf.name_scope("add_positional_embedding"):
    _, length, depth = common_layers.shape_list(x)
    table = tf.cast(tf.get_variable(name, [max_length, depth]), x.dtype)
    if positions is not None:
      return x + tf.gather(table, tf.to_int32(positions))

    extra_rows = tf.maximum(0, length - max_length)

    def _prefix_rows():
      return tf.slice(table, [0, 0], [length, -1])

    def _zero_padded_rows():
      return tf.pad(table, [[0, extra_rows], [0, 0]])

    rows = tf.cond(tf.less(length, max_length), _prefix_rows,
                   _zero_padded_rows)
    return x + tf.expand_dims(rows, 0)
def add_positional_embedding_nd(x, max_length, name=None):
  """Adds n-dimensional positional embedding.

  The embeddings add to all positional dimensions of the tensor. For each
  positional dimension i a variable named `name + "_i"` of size max_length
  along that dimension is created, scaled by depth**0.5, sliced to the actual
  size and added to x.

  Args:
    x: Tensor with shape [batch, p1 ... pn, depth]. It has n positional
      dimensions, i.e., 1 for text, 2 for images, 3 for video, etc.
    max_length: int representing static maximum size of any dimension.
    name: str representing name of the embedding tf.Variable. Required: it is
      used as a prefix for the per-dimension variable names.

  Returns:
    Tensor of same shape (and dtype) as x.
  """
  with tf.name_scope("add_positional_embedding_nd"):
    x_shape = common_layers.shape_list(x)
    num_dims = len(x_shape) - 2
    depth = x_shape[-1]
    base_shape = [1] * (num_dims + 1) + [depth]
    base_start = [0] * (num_dims + 2)
    base_size = [-1] + [1] * num_dims + [depth]
    for i in range(num_dims):
      shape = base_shape[:]
      start = base_start[:]
      size = base_size[:]
      shape[i + 1] = max_length
      size[i + 1] = x_shape[i + 1]
      var = tf.get_variable(
          name + "_%d" % i,
          shape,
          initializer=tf.random_normal_initializer(0, depth**-0.5))
      var = var * depth**0.5
      # Cast to x's dtype (as add_positional_embedding does) so this also
      # works for reduced-precision activations; no-op for float32 x.
      var = tf.cast(var, x.dtype)
      x += tf.slice(var, start, size)
    return x
def make_edge_vectors(adjacency_matrix, num_edge_types, depth, name=None):
  """Looks up a learned vector for every edge type in an adjacency matrix.

  Lookup is done as a one-hot matmul rather than a gather so that it works on
  TPUs.

  Args:
    adjacency_matrix: A [batch, num_nodes, num_nodes] tensor of ints.
    num_edge_types: Number of different edge types
    depth: Number of channels
    name: a string

  Returns:
    A [batch, num_nodes, num_nodes, depth] vector of tensors
  """
  with tf.variable_scope(name, default_name="edge_vectors"):
    adj_shape = common_layers.shape_list(adjacency_matrix)
    edge_table = tf.get_variable(
        "adj_vectors",
        [num_edge_types, depth],
        initializer=tf.random_normal_initializer(0, depth**-0.5))
    edge_table = edge_table * (depth**0.5)
    # [batch, num_nodes, num_nodes, num_edge_types]
    one_hot_edges = tf.one_hot(adjacency_matrix, num_edge_types)
    flat_one_hot = tf.reshape(tf.to_float(one_hot_edges), [-1, num_edge_types])
    flat_vectors = tf.matmul(flat_one_hot, edge_table)
    output_shape = [adj_shape[0], adj_shape[1], adj_shape[2], depth]
    return tf.reshape(flat_vectors, output_shape)
class LshGating(object):
  """Class to split key/queries into separate buckets.

  Each input row is projected onto `nb_hyperplanes` vectors stored in a
  variable; the signs of the projections form a +/-1 bit pattern, and the row
  is assigned to the bucket whose bit pattern matches best.
  """

  def __init__(self, depth, nb_hyperplanes, nb_replicat=1, trainable=False):
    """Construct the gating function parameters.

    Compute the gates for a single head.

    Args:
      depth (int): Dimension of the key/queries to dispatch
      nb_hyperplanes (int): Nb of vectors used to split the space. Determines
        the number of buckets (2^nb_hyperplanes).
      nb_replicat (int): Redundancy to avoid the edge cases (to be in one bucket
        the input should be in a majority). Only 1 is currently supported.
      trainable (bool): Passed as `trainable` to the hyperplane variable. A
        balance loss forcing the hyperplanes to divide the key/query space
        evenly is intended but not implemented yet.
    """
    self.depth = depth
    self.nb_hyperplanes = nb_hyperplanes
    self.nb_buckets = 2**nb_hyperplanes
    self.nb_replicat = nb_replicat  # Unused for now
    self.trainable = trainable  # Unused for now

    self.dispatchers = {}

    assert self.nb_replicat == 1  # For now

    with tf.variable_scope("lsh_gating"):
      # Vectors defining the hyperplanes: [depth, nb_hyperplanes * nb_replicat]
      self.t_vectors = tf.get_variable(
          "vector",
          shape=(self.depth, self.nb_hyperplanes * self.nb_replicat),
          dtype=tf.float32,
          trainable=self.trainable,
      )
      # Projection from the bit space to the similarity score space:
      # row i is the +/-1 bit pattern of bucket i, [nb_buckets, nb_hyperplanes].
      self.t_group = tf.constant(
          [self._idx_to_bits(i) for i in range(self.nb_buckets)],
          dtype=tf.float32,
          name="group")

  def _idx_to_bits(self, i):
    """Convert a group index to its +/-1 bit representation (MSB first)."""
    bits = bin(i)[2:].zfill(self.nb_hyperplanes)  # Pad the bits str with 0
    return [-1.0 if b == "0" else 1.0 for b in bits]

  @expert_utils.add_name_scope("lsh_gating")
  def get_gates(self, x):
    """Return the bucket id of the given tensor.

    Args:
      x (tf.Tensor): float32 of shape [length, depth]

    Returns:
      tf.Tensor: float32 one-hot tensor of shape [length, nb_buckets]
        containing the id of the bucket of each row of x.
    """
    # Gating must not propagate gradients into the rest of the network.
    x = tf.stop_gradient(x)
    # [length, depth] * [depth, nb_hyperplanes * nb_replicat]
    x = tf.matmul(x, self.t_vectors)
    # [length, nb_hyperplanes * nb_replicat]
    x = tf.sign(x)  # Which side of each hyperplane the keys are on.
    # [length, nb_hyperplanes] * [nb_hyperplanes, nb_buckets]
    x = tf.matmul(x, self.t_group, transpose_b=True) / self.nb_hyperplanes
    # Similarity score for each bucket in [-1, 1]: [length, nb_buckets].
    # Pick the most similar bucket for each row.
    x = tf.argmax(x, axis=-1)
    # [length]
    # One-hot for compatibility with the sparse dispatcher.
    x = tf.one_hot(x, self.nb_buckets)
    # TODO(epot): Use a loss to force an even distribution
    return x
@expert_utils.add_name_scope()
def embedding_to_padding(emb):
  """Derives a padding mask from all-zero embedding vectors.

  symbol_modality has been hacked to emit all-zero embeddings for padding, so
  a vector whose absolute values sum to zero is treated as padding.

  Args:
    emb: a Tensor with shape [..., depth].

  Returns:
    a float Tensor with shape [...]: 1.0 where the embedding is all zero,
    0.0 elsewhere.
  """
  magnitude = tf.reduce_sum(tf.abs(emb), axis=-1)
  is_padding = tf.equal(magnitude, 0.0)
  return tf.to_float(is_padding)
@expert_utils.add_name_scope()
def padding_to_length(padding):
  """Counts the non-padding positions along the last axis.

  Args:
    padding: a Tensor with shape [..., length], 1.0 at padding positions.

  Returns:
    an int32 Tensor with shape [...].
  """
  real_tokens = tf.reduce_sum(1.0 - padding, axis=-1)
  return tf.to_int32(real_tokens)
@expert_utils.add_name_scope()
def attention_bias_local(length, max_backward, max_forward):
  """Builds a banded attention bias.

  A position may attend to at most `max_backward` positions before it and
  `max_forward` positions after it; all other logits receive -1e9. This masks
  but does not save any computation.

  Args:
    length: int
    max_backward: int, maximum distance backward to attend. Negative values
      indicate unlimited.
    max_forward: int, maximum distance forward to attend. Negative values
      indicate unlimited.

  Returns:
    a `Tensor` with shape [1, 1, length, length].
  """
  bias_shape = [1, 1, length, length]
  allowed = common_layers.ones_matrix_band_part(
      length, length, max_backward, max_forward, out_shape=bias_shape)
  return -1e9 * (1.0 - allowed)
@expert_utils.add_name_scope()
def attention_bias_lower_triangle(length):
  """Builds a causal attention bias.

  Each query may attend to every earlier position and to itself, but not to
  later positions.

  Args:
    length: a Scalar.

  Returns:
    a `Tensor` with shape [1, 1, length, length].
  """
  unlimited_backward = -1
  no_forward = 0
  return attention_bias_local(length, unlimited_backward, no_forward)
@expert_utils.add_name_scope()
def attention_bias_same_segment(query_segment_id, memory_segment_id):
  """Builds a bias that only lets positions in the same segment attend.

  Args:
    query_segment_id: a float `Tensor` with shape [batch, query_length].
    memory_segment_id: a float `Tensor` with shape [batch, memory_length].

  Returns:
    a `Tensor` with shape [batch, 1, query_length, memory_length].
  """
  query_ids = tf.expand_dims(query_segment_id, 2)
  memory_ids = tf.expand_dims(memory_segment_id, 1)
  crosses_segment = tf.to_float(tf.not_equal(query_ids, memory_ids))
  bias = crosses_segment * large_compatible_negative(memory_segment_id.dtype)
  # Add a heads axis: [batch, 1, query_length, memory_length].
  return tf.expand_dims(bias, axis=1)
@expert_utils.add_name_scope()
def attention_bias_ignore_padding(memory_padding):
  """Turns a memory padding mask into an additive attention bias.

  Args:
    memory_padding: a float `Tensor` with shape [batch, memory_length].

  Returns:
    a `Tensor` with shape [batch, 1, 1, memory_length].
  """
  negative = large_compatible_negative(memory_padding.dtype)
  bias = memory_padding * negative
  with_heads_axis = tf.expand_dims(bias, axis=1)
  return tf.expand_dims(with_heads_axis, axis=1)
@expert_utils.add_name_scope()
def attention_bias_to_padding(attention_bias,
                              cast_fn=(lambda x: tf.cast(x, tf.float32))):
  """Recovers a padding mask from an attention bias.

  This inverts attention_bias_ignore_padding(): entries below -1 (the large
  negative padding value) become 1, everything else becomes 0.

  Args:
    attention_bias: a `Tensor` with shape [batch, 1, 1, memory_length], as
      returned by attention_bias_ignore_padding().
    cast_fn: function used to cast to output type.

  Returns:
    a Tensor with shape [batch, memory_length] with 1.0 in padding positions
    and 0.0 in non-padding positions. Type is determined by cast_fn.
  """
  is_padding = tf.less(attention_bias, -1)
  padding = cast_fn(is_padding)
  return tf.squeeze(padding, axis=[1, 2])
@expert_utils.add_name_scope()
def attention_bias_prepend_inputs_full_attention(padding):
  """Self-attention bias for prepend_mode="prepend_inputs_full_attention".

  See prepend_inputs in common_hparams.py. The "inputs" part of the sequence
  is fully connected, while the "targets" part is causally masked.

  Args:
    padding: a float `Tensor` with shape [batch, length] with
      ones in positions corresponding to padding. In each row, a single
      padding position separates the input part from the target part.

  Returns:
    a `Tensor` with shape [batch, 1, length, length].
  """
  # 0 for the source and the separator, 1 for every target position (all
  # positions after the first padding).
  in_target = tf.cumsum(padding, axis=1, exclusive=True)
  # Index within the target part; 0 for every source position.
  target_pos = tf.cumsum(in_target, axis=1)
  key_pos = tf.expand_dims(target_pos, 1)    # [batch, 1, length]
  query_pos = tf.expand_dims(target_pos, 2)  # [batch, length, 1]
  # A query may not look at a key with a larger target position.
  forbidden = tf.greater(key_pos, query_pos)
  per_query_bias = tf.to_float(forbidden) * -1e9
  return tf.expand_dims(per_query_bias, 1)
@expert_utils.add_name_scope()
def attention_bias_proximal(length):
  """Bias that favors attending to nearby positions.

  The bias for positions i and j is -log(1 + |i - j|).

  Args:
    length: an integer scalar.

  Returns:
    a Tensor with shape [1, 1, length, length]
  """
  idx = tf.to_float(tf.range(length))
  distance = tf.abs(tf.expand_dims(idx, 0) - tf.expand_dims(idx, 1))
  bias = -tf.log1p(distance)
  bias = tf.expand_dims(bias, 0)
  return tf.expand_dims(bias, 0)
@expert_utils.add_name_scope()
def attention_bias_batch(batch_coordinates_q,
                         batch_coordinates_k=None,
                         condition_fn=None):
  """Builds a mask so that elements of different batches do not attend.

  Args:
    batch_coordinates_q: Int-like Tensor of shape [length_q, 1] containing the
      coordinates of the batches
    batch_coordinates_k: Int-like Tensor of shape [length_k, 1] containing the
      coordinates of the batches. If None, do self-attention.
    condition_fn: Callable mapping the [length_q, length_k] coordinate
      difference (key minus query) to a 0/1 mask. Must be provided.

  Returns:
    Float-like Tensor of shape [length_q, length_k] containing either 0 or
    -infinity (-1e9).
  """
  if batch_coordinates_k is None:
    batch_coordinates_k = batch_coordinates_q

  def _as_float_vector(coords):
    # Convert to float first because of b/25387198.
    return tf.to_float(tf.squeeze(coords, 1))

  # Broadcast [length_q, 1] against [1, length_k].
  query_col = tf.expand_dims(_as_float_vector(batch_coordinates_q), 1)
  key_row = tf.expand_dims(_as_float_vector(batch_coordinates_k), 0)
  mask = condition_fn(key_row - query_col)
  return mask * -1e9
# Stops individual sequences in the same batch from attending to each other:
# any nonzero coordinate difference is clamped to 1, i.e. fully masked.
attention_bias_coordinates = functools.partial(
    attention_bias_batch,
    condition_fn=lambda delta: tf.minimum(1.0, tf.abs(delta)),
)

# Like an upper-triangular (causal) mask, but compatible with dispatching.
# Elements may attend to themselves (otherwise this would use delta + 1.0).
# There is no tf.abs because order matters; the difference is clamped to
# [0, 1] with tf.maximum / tf.minimum.
attention_bias_future = functools.partial(
    attention_bias_batch,
    condition_fn=lambda delta: tf.maximum(0.0, tf.minimum(1.0, delta)),
)
@expert_utils.add_name_scope()
def split_last_dimension(x, n):
  """Splits the last axis of size m into two axes of sizes n and m / n.

  Args:
    x: a Tensor with shape [..., m]
    n: an integer.

  Returns:
    a Tensor with shape [..., n, m/n]
  """
  dims = common_layers.shape_list(x)
  last = dims[-1]
  # Divisibility can only be checked when both sizes are static.
  if isinstance(last, int) and isinstance(n, int):
    assert last % n == 0
  target_shape = dims[:-1] + [n, last // n]
  return tf.reshape(x, target_shape)
@expert_utils.add_name_scope()
def combine_last_two_dimensions(x):
  """Merges the last two axes of x into one.

  Args:
    x: a Tensor with shape [..., a, b]

  Returns:
    a Tensor with shape [..., ab]
  """
  dims = common_layers.shape_list(x)
  merged = dims[-2] * dims[-1]
  return tf.reshape(x, dims[:-2] + [merged])
@expert_utils.add_name_scope()
def combine_first_two_dimensions(x):
  """Merges the first two axes of x into one, preserving static shape info.

  Args:
    x: a Tensor with shape [a, b, ...]

  Returns:
    a Tensor with shape [ab, ...]
  """
  trailing_dims = common_layers.shape_list(x)[2:]
  merged = tf.reshape(x, tf.concat([[-1], trailing_dims], 0))
  # The dynamic reshape loses static shape info; restore it where known.
  static_dims = x.get_shape().dims
  first, second = static_dims[:2]
  leading = first * second if first and second else None
  merged.set_shape([leading] + static_dims[2:])
  return merged
@expert_utils.add_name_scope()
def split_heads(x, num_heads):
  """Moves channel groups into a new heads axis at position 1.

  Args:
    x: a Tensor with shape [batch, length, channels]
    num_heads: an integer

  Returns:
    a Tensor with shape [batch, num_heads, length, channels / num_heads]
  """
  per_head = split_last_dimension(x, num_heads)  # [b, len, heads, c/heads]
  heads_first = [0, 2, 1, 3]
  return tf.transpose(per_head, heads_first)
@expert_utils.add_name_scope()
def split_heads_2d(x, num_heads):
  """Moves channel groups of an image tensor into a heads axis at position 1.

  Args:
    x: a Tensor with shape [batch, height, width, channels]
    num_heads: an integer

  Returns:
    a Tensor with shape [batch, num_heads, height, width, channels / num_heads]
  """
  per_head = split_last_dimension(x, num_heads)  # [b, h, w, heads, c/heads]
  heads_first = [0, 3, 1, 2, 4]
  return tf.transpose(per_head, heads_first)
def split_heads_nd(x, num_heads):
  """Splits the depth (last) axis into heads for an n-d positional tensor.

  Args:
    x: a [batch, d1, ..., dn, depth] tensor
    num_heads: an integer

  Returns:
    a [batch, num_heads, d1, ..., dn, depth // num_heads]
  """
  n = len(common_layers.shape_list(x)) - 2
  # After splitting, the heads axis sits at n + 1; move it next to batch.
  perm = [0, n + 1] + list(range(1, n + 1)) + [n + 2]
  per_head = split_last_dimension(x, num_heads)
  return tf.transpose(per_head, perm)
@expert_utils.add_name_scope()
def combine_heads(x):
  """Inverse of split_heads: folds the heads axis back into channels.

  Args:
    x: a Tensor with shape [batch, num_heads, length, channels / num_heads]

  Returns:
    a Tensor with shape [batch, length, channels]
  """
  heads_last = tf.transpose(x, [0, 2, 1, 3])  # [b, len, heads, c/heads]
  return combine_last_two_dimensions(heads_last)
@expert_utils.add_name_scope()
def combine_heads_2d(x):
  """Inverse of split_heads_2d: folds the heads axis back into channels.

  Args:
    x: a Tensor with shape
      [batch, num_heads, height, width, channels / num_heads]

  Returns:
    a Tensor with shape [batch, height, width, channels]
  """
  heads_last = tf.transpose(x, [0, 2, 3, 1, 4])  # [b, h, w, heads, c/heads]
  return combine_last_two_dimensions(heads_last)
def combine_heads_nd(x):
  """Inverse of split_heads_nd: folds the heads axis back into depth.

  Args:
    x: a [batch, num_heads, d1, ..., dn, depth // num_heads] tensor

  Returns:
    a [batch, d1, ...., dn, depth] tensor
  """
  n = len(common_layers.shape_list(x)) - 3
  # Move the heads axis (1) to just before the per-head depth axis.
  perm = [0] + list(range(2, n + 2)) + [1, n + 2]
  heads_last = tf.transpose(x, perm)
  return combine_last_two_dimensions(heads_last)
def attention_image_summary(attn, image_shapes=None):
  """Writes a color image summary of attention weights.

  Heads are padded to a multiple of 3 and grouped so that each group maps to
  an RGB triple (max over the heads in the group). Values are raised to the
  power 0.2 to compress their dynamic range.

  Args:
    attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
    image_shapes: optional tuple of integer scalars.
      If the query positions and memory positions represent the
      pixels of flattened images, then pass in their dimensions:
        (query_rows, query_cols, memory_rows, memory_cols).
      If the query positions and memory positions represent the
      pixels x channels of flattened images, then pass in their dimensions:
        (query_rows, query_cols, query_channels,
         memory_rows, memory_cols, memory_channels).
  """
  weights = tf.cast(attn, tf.float32)
  heads = common_layers.shape_list(weights)[1]
  # [batch, query_length, memory_length, num_heads]
  picture = tf.pow(tf.transpose(weights, [0, 2, 3, 1]), 0.2)
  # Pad heads up to a multiple of 3, then collapse each group into one of RGB.
  head_padding = tf.mod(-heads, 3)
  picture = tf.pad(picture, [[0, 0], [0, 0], [0, 0], [0, head_padding]])
  picture = tf.reduce_max(split_last_dimension(picture, 3), 4)
  if image_shapes is not None and len(image_shapes) == 4:
    q_rows, q_cols, m_rows, m_cols = list(image_shapes)
    picture = tf.reshape(picture, [-1, q_rows, q_cols, m_rows, m_cols, 3])
    # Interleave to [batch, q_rows, m_rows, q_cols, m_cols, 3].
    picture = tf.transpose(picture, [0, 1, 3, 2, 4, 5])
    picture = tf.reshape(picture, [-1, q_rows * m_rows, q_cols * m_cols, 3])
  elif image_shapes is not None:
    assert len(image_shapes) == 6
    (q_rows, q_cols, q_channels,
     m_rows, m_cols, m_channels) = list(image_shapes)
    picture = tf.reshape(
        picture,
        [-1, q_rows, q_cols, q_channels, m_rows, m_cols, m_channels, 3])
    # -> [batch, q_rows, m_rows, q_channels, q_cols, m_cols, m_channels, 3]
    picture = tf.transpose(picture, [0, 1, 4, 3, 2, 5, 6, 7])
    picture = tf.reshape(
        picture,
        [-1, q_rows * m_rows * q_channels, q_cols * m_cols * m_channels, 3])
  tf.summary.image("attention", picture, max_outputs=1)
def grouped_attention_multihead(query_antecedent,
                                memory_antecedent,
                                total_key_depth,
                                total_value_depth,
                                output_depth,
                                num_heads,
                                num_groups,
                                memory_target_density=2.0,
                                multiplicative_overhead=1.25,
                                additive_overhead=8.0,
                                mask_right=False,
                                make_image_summary=True,
                                name=None):
  """Multi-head dot-product attention with sparsity.

  For each attention head, the queries are partitioned into groups.
  For each group, only a subset of the key-value pairs are considered.

  The choices of groups are selected based on trained predictors of
  the total attention given the group inclusion.

  memory_target_density indicates the average number of groups in which
  a key-value pair should participate.

  We use auxiliary losses to ensure that each group contains roughly
  the same number of queries and the same number of key-value pairs.
  If for a given sequence, the actual number of queries/pairs sent to
  an expert exceeds this target by a factor of more than
  multiplicative_overhead, then the last ones are dropped. We use
  this drop-last policy to avoid bleeding information backwards, which
  is necessary when using this function with autoregressive
  prediction.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    num_groups: an integer
    memory_target_density: a floating point scalar
    multiplicative_overhead: a floating point scalar
    additive_overhead: a floating point scalar
    mask_right: a boolean; if True, a query is masked from memory positions
      whose length coordinate is greater than its own.
    make_image_summary: a boolean
    name: an optional string

  Returns:
    A tuple (output, extra_loss):
      output: a Tensor with shape [batch, length_q, output_depth].
      extra_loss: a scalar auxiliary loss (already multiplied by 1e-3) that
        trains the group predictors and balances the group sizes.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  batch = common_layers.shape_list(query_antecedent)[0]
  length_q = common_layers.shape_list(query_antecedent)[1]
  length_kv = common_layers.shape_list(memory_antecedent)[1]

  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  depth_qk = total_key_depth // num_heads
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  depth_v = total_value_depth // num_heads
  with tf.variable_scope(
      name, default_name="multihead_attention_sparse",
      values=[query_antecedent, memory_antecedent]):
    q = common_layers.dense(
        query_antecedent, total_key_depth, use_bias=False, name="q_transform")
    kv = common_layers.dense(
        memory_antecedent,
        total_key_depth + total_value_depth,
        use_bias=False,
        name="kv_transform")
    q = split_heads(q, num_heads)
    kv = split_heads(kv, num_heads)
    # Make predictions about q_total and m_total.
    # These are used to determine group inclusion.
    # We will train these by auxiliary losses.  We use stop_gradient here
    # to keep these losses from back-propagating to the rest of the model.
    # We add biases that help balance the usage of the experts.
    q_pred = common_layers.dense(
        tf.stop_gradient(query_antecedent),
        num_heads * num_groups,
        use_bias=False,
        name="q_pred")
    q_pred = split_heads(q_pred, num_heads)
    q_bias = tf.get_variable("q_bias", [1, num_heads, 1, num_groups])
    q_pred_biased = q_pred + q_bias
    m_pred = common_layers.dense(
        tf.stop_gradient(memory_antecedent),
        num_heads * num_groups,
        use_bias=False,
        name="m_pred")
    m_pred = split_heads(m_pred, num_heads)
    m_bias = tf.get_variable("m_bias", [1, num_heads, 1, num_groups])
    m_pred_biased = m_pred + m_bias
    q *= depth_qk**-0.5
    # q, kv, q_pred, m_pred are all [batch, heads, length_[q/m], ?]
    # now reshape them all to [batch * heads, length, ?]
    q = combine_first_two_dimensions(q)
    kv = combine_first_two_dimensions(kv)
    q_pred = combine_first_two_dimensions(q_pred)
    m_pred = combine_first_two_dimensions(m_pred)
    q_pred_biased = combine_first_two_dimensions(q_pred_biased)
    m_pred_biased = combine_first_two_dimensions(m_pred_biased)
    # Each query goes to exactly one group (its argmax prediction); each
    # memory position goes to every group whose biased prediction is > 0.
    q_group = tf.argmax(q_pred_biased, axis=2)
    q_requests = tf.one_hot(q_group, num_groups, axis=-1)
    m_requests = tf.to_float(tf.greater(m_pred_biased, 0.0))
    # include first memory position in all groups, to avoid division by zero.
    m_requests = tf.maximum(
        m_requests, tf.reshape(tf.one_hot([0], length_kv), [1, length_kv, 1]))
    q_group_size = tf.reduce_sum(q_requests, 1)
    m_group_size = tf.reduce_sum(m_requests, 1)
    q_group_target_size = tf.to_float(length_q) / tf.to_float(num_groups)
    m_group_target_size = (
        tf.to_float(length_kv) * memory_target_density /
        tf.to_float(num_groups))
    capacity_q = tf.minimum(
        length_q,
        tf.to_int32(q_group_target_size * multiplicative_overhead +
                    additive_overhead))
    capacity_m = tf.minimum(
        length_kv,
        tf.to_int32(m_group_target_size * multiplicative_overhead +
                    additive_overhead))
    q_dispatcher = expert_utils.TruncatingDispatcher(q_requests, capacity_q)
    m_dispatcher = expert_utils.TruncatingDispatcher(m_requests, capacity_m)
    q_gates = q_dispatcher.gates()
    m_gates = m_dispatcher.gates()
    dispatched_q = q_dispatcher.dispatch(q)
    dispatched_kv = m_dispatcher.dispatch(kv)
    # dispatched_q: [batch * num_heads, num_groups, capacity_q, depth_qk]
    # dispatched_kv:
    #   [batch * num_heads, num_groups, capacity_m, depth_qk + depth_v]
    k, v = tf.split(dispatched_kv, [depth_qk, depth_v], axis=3)
    logits = tf.matmul(dispatched_q, k, transpose_b=True)
    # Large negative bias on padded (non-dispatched) memory slots.
    bias = tf.expand_dims((m_dispatcher.nonpadding() - 1.0) * 1e9, 2)
    if mask_right:
      q_coordinate = tf.to_float(
          tf.expand_dims(q_dispatcher.length_coordinate(), 3))
      m_coordinate = tf.to_float(
          tf.expand_dims(m_dispatcher.length_coordinate(), 2))
      bias += tf.to_float(tf.greater(m_coordinate, q_coordinate)) * -1e9
    logits += bias
    log_weights = tf.nn.log_softmax(logits)
    weights = tf.exp(log_weights)
    # For each query, this is the log of the sum of the unnormalized weights.
    q_total = tf.stop_gradient(logits[:, :, :, :1] - log_weights[:, :, :, :1])
    # For each key, this is the sum of the normalized weights.
    m_total = tf.expand_dims(
        tf.reduce_sum(tf.stop_gradient(weights), axis=2), -1)
    o = tf.matmul(weights, v)
    o = q_dispatcher.combine(o)

    o = tf.reshape(o, [batch, num_heads, length_q, depth_v])
    o = combine_heads(o)
    o = common_layers.dense(
        o, output_depth, use_bias=False, name="output_transform")

    m_total = m_dispatcher.combine(m_total)
    q_total = q_dispatcher.combine(q_total)
    q_total = tf.squeeze(q_total, -1)
    m_total = tf.squeeze(m_total, -1)
    # Compute summed m predictions for all groups
    m_pred_used = tf.reduce_sum(tf.exp(m_pred) * m_gates, axis=2)
    q_pred_used = tf.reduce_sum(q_pred * q_gates, axis=2)
    epsilon = 1e-3
    m_pred_used = tf.log(m_pred_used + epsilon)
    m_total = tf.log(m_total + epsilon)
    m_loss = tf.nn.l2_loss(m_total - m_pred_used)
    q_loss = tf.nn.l2_loss(
        (q_total - q_pred_used) * tf.reduce_sum(q_gates, axis=2))

    q_loss /= tf.to_float(batch * length_q)
    m_loss /= tf.to_float(batch * length_kv)

    # We would like the query groups to be equal sized.  The group
    # size is discrete, so we need some trick here.  We add a loss
    # proportional to the product of the group size and the
    # predictions for that group.  This encourages the predictions to
    # decrease for groups that are too big.
    q_group_deviation = (q_group_size / q_group_target_size) - 1.0
    q_balance_loss = tf.reduce_sum(
        tf.reduce_mean(q_pred_biased, axis=1) *
        q_group_deviation) / tf.to_float(batch)
    m_group_deviation = (m_group_size / m_group_target_size) - 1.0
    m_balance_loss = tf.reduce_sum(
        tf.reduce_mean(m_pred_biased, axis=1) *
        m_group_deviation) / tf.to_float(batch)

    # The losses in this function only propagate back to variables
    # defined in this function, and the losses outside of this
    # function only propagate back to variables outside of this
    # function.  Assuming some kind of adaptive learning algorithm,
    # it should not matter how much we scale the losses in this function.
    # Still we scale them down a lot so that they should not show up
    # much in the overall loss for the model.
    extra_loss_multiplier = 1e-3
    extra_loss = q_loss + m_loss + q_balance_loss + m_balance_loss
    extra_loss *= extra_loss_multiplier

    # Show a bunch of summaries.  (make_image_summary is already required
    # here, so the image summaries below need no second check.)
    if common_layers.should_generate_summaries() and make_image_summary:
      tf.summary.histogram("q_group_size", q_group_size)
      tf.summary.histogram("m_group_size", m_group_size)
      tf.summary.scalar("q_loss", q_loss)
      tf.summary.scalar("m_loss", m_loss)
      tf.summary.scalar("q_balance_loss", q_balance_loss)
      tf.summary.scalar("m_balance_loss", m_balance_loss)
      tf.summary.histogram("m_pred_used", m_pred_used)
      tf.summary.histogram("m_total", m_total)
      tf.summary.histogram("q_pred_used", q_pred_used)
      tf.summary.histogram("q_total", q_total)
      # image summaries are expensive.
      # So we restrict them to head_num<4, query_position<512, batch_index=0.
      trunc_heads = min(4, num_heads)
      trunc_length_q = tf.minimum(length_q, 512)
      # We recompute the attention for the first example, in an inefficient
      # way - masking.  This lets us show pretty pictures.
      # [trunc_heads, length_q, group]
      q_gates_trunc = q_gates[:trunc_heads, :trunc_length_q, :]
      # [trunc_heads, length_kv, group]
      m_gates_trunc = m_gates[:trunc_heads, :, :]
      grouping_mask = tf.matmul(
          q_gates_trunc, m_gates_trunc, transpose_b=True)
      q_trunc = q[:trunc_heads, :trunc_length_q, :]
      k_trunc = kv[:trunc_heads, :, :depth_qk]
      logits_trunc = tf.matmul(q_trunc, k_trunc, transpose_b=True)
      if mask_right:
        band = common_layers.ones_matrix_band_part(trunc_length_q, length_kv,
                                                   -1, 0)
        trunc_bias = tf.expand_dims((1.0 - band) * -1e9, 0)
        logits_trunc += trunc_bias
      att_trunc = tf.nn.softmax(logits_trunc)
      mask_coverage = tf.reduce_sum(grouping_mask * att_trunc) / (
          tf.to_float(trunc_length_q) * trunc_heads)
      tf.summary.scalar("coverage", mask_coverage)
      att_trunc_hdr = tf.pow(att_trunc, 0.2)  # for high-dynamic-range
      mask_channel = grouping_mask * tf.maximum(att_trunc_hdr, 0.3)
      image = tf.stack([att_trunc_hdr, mask_channel, mask_channel], axis=3)
      tf.summary.image("att", image, max_outputs=trunc_heads)
      # show one group (group 0) for each head.
      att_per_group = tf.expand_dims(weights[:trunc_heads, 0, :, :], -1)
      tf.summary.image(
          "att_per_group_0",
          tf.pow(att_per_group, 0.2),
          max_outputs=trunc_heads)
    return o, extra_loss
def harden_attention_weights(weights, k, gumbel_noise_weight):
  """Make attention weights non-0 only on the top k ones.

  Args:
    weights: a Tensor of attention weights. Re-normalization is done over the
      last axis (presumably also the axis top_kth_iterative selects over --
      TODO confirm).
    k: the number of top entries to select (passed to top_kth_iterative).
    gumbel_noise_weight: if > 0., Gumbel noise scaled by this factor is added
      to `weights` before the top-k selection.

  Returns:
    A Tensor shaped like `weights` in which every entry not strictly greater
    than the threshold returned by top_kth_iterative is 0, re-normalized to
    sum to 1 along the last axis. Rows that are all zero after thresholding
    stay zero, because the denominator is clamped to 1e-6.
  """
  if gumbel_noise_weight > 0.:
    # Keep the uniform samples away from 0 and 1 so both logs stay finite.
    uniform = tf.random_uniform(tf.shape(weights),
                                minval=1e-5,
                                maxval=1 - 1e-5)
    gumbel_noise = -tf.log(-tf.log(uniform))
    weights += gumbel_noise * gumbel_noise_weight

  # Subtract the top-kth weight and zero-out all lower ones.
  # Note that currently in case of numerical ties it will retain more
  # than k elements. In the future, we may want to avoid this.
  weights -= common_layers.top_kth_iterative(weights, k)
  weights = tf.nn.relu(weights)
  # Re-normalize the weights.
  weights_sum = tf.reduce_sum(weights, axis=-1, keepdims=True)
  weights_sum = tf.maximum(weights_sum, 1e-6)  # Avoid division by 0.
  weights /= weights_sum
  return weights
def dot_product_attention(q,
                          k,
                          v,
                          bias,
                          dropout_rate=0.0,
                          image_shapes=None,
                          name=None,
                          make_image_summary=True,
                          save_weights_to=None,
                          dropout_broadcast_dims=None,
                          activation_dtype=None,
                          weight_dtype=None,
                          hard_attention_k=0,
                          gumbel_noise_weight=0.0):
  """Dot-product attention.

  Computes softmax(q k^T + bias) v. The softmax may be hardened to the top-k
  entries and the attention weights are dropped out before being applied to v.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias()); cast to the logits' dtype before
      being added.
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
      The pre-dropout weights are stored, along with the logits under
      "<scope>/logits".
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    activation_dtype: Used to define function activation dtype when using
      mixed precision.
    weight_dtype: The dtype weights are stored in when using mixed precision
    hard_attention_k: integer, if > 0 triggers hard attention (picking top-k)
    gumbel_noise_weight: if > 0, apply Gumbel noise with weight
      `gumbel_noise_weight` before picking top-k. This is a no op if
      hard_attention_k <= 0.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # scores: [..., length_q, length_kv]
    scores = tf.matmul(q, k, transpose_b=True)
    if bias is not None:
      scores += common_layers.cast_like(bias, scores)
    # Softmax in higher precision when the scores are fp16.
    scores = maybe_upcast(scores, activation_dtype, weight_dtype)
    probs = tf.nn.softmax(scores, name="attention_weights")
    if hard_attention_k > 0:
      probs = harden_attention_weights(probs, hard_attention_k,
                                       gumbel_noise_weight)
    probs = common_layers.cast_like(probs, q)
    if save_weights_to is not None:
      save_weights_to[scope.name] = probs
      save_weights_to[scope.name + "/logits"] = scores
    # Drop out attention links for each head.
    probs = common_layers.dropout_with_broadcast_dims(
        probs, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      attention_image_summary(probs, image_shapes)
    return tf.matmul(probs, v)
def _generate_relative_positions_matrix(length_q, length_k,
                                        max_relative_position,
                                        cache=False):
  """Generates matrix of relative positions between inputs.

  Each entry is (key_position - query_position), clipped to
  [-max_relative_position, max_relative_position] and then shifted by
  max_relative_position, so every entry lies in [0, 2 * max_relative_position].

  In cache mode the result has shape [1, length_k] and describes a single query
  sitting at the last key position. Otherwise it has shape [length_q, length_k],
  with the queries aligned to the final length_q key positions.
  """
  if cache:
    # Distances from every key to the last key position: -(length_k-1) .. 0.
    distance_mat = tf.expand_dims(tf.range(-length_k + 1, 1, 1), 0)
  elif length_q == length_k:
    positions = tf.range(length_q)
    distance_mat = positions[None, :] - positions[:, None]
  else:
    key_positions = tf.range(length_k)
    query_positions = key_positions[-length_q:]
    distance_mat = key_positions[None, :] - query_positions[:, None]
  clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                             max_relative_position)
  # Shift into [0, 2 * max_relative_position] so values can index a table;
  # each integer still identifies a unique relative position difference.
  return clipped + max_relative_position
def _generate_relative_positions_embeddings(length_q, length_k, depth,
                                            max_relative_position, name,
                                            cache=False):
  """Generates tensor of size [1 if cache else length_q, length_k, depth].

  Looks up a learned [2 * max_relative_position + 1, depth] table (variable
  "embeddings" under scope `name`) with the clipped relative-position indices.
  """
  with tf.variable_scope(name):
    index_mat = _generate_relative_positions_matrix(
        length_q, length_k, max_relative_position, cache=cache)
    # One learned row per clipped relative distance.
    num_rows = 2 * max_relative_position + 1
    table = tf.get_variable("embeddings", [num_rows, depth])
    return tf.gather(table, index_mat)
def _relative_attention_inner(x, y, z, transpose):
  """Relative position-aware dot-product attention inner calculation.

  Returns x @ y (content term) plus a per-position product of x with z
  (position term). Batch and heads are folded together so that the position
  term is a single batched matmul over the length axis, which avoids
  broadcasting z across batch and heads.

  Args:
    x: Tensor with shape [batch_size, heads, length or 1, length or depth].
    y: Tensor with shape [batch_size, heads, length or 1, depth].
    z: Tensor with shape [length or 1, length, depth].
    transpose: Whether to transpose inner matrices of y and z. Should be true if
        last dimension of x is depth, not length.

  Returns:
    A Tensor with shape [batch_size, heads, length, length or depth].
  """
  batch_size = tf.shape(x)[0]
  heads = x.get_shape().as_list()[1]
  length = tf.shape(x)[2]

  # Content term: [batch_size, heads, length or 1, length or depth].
  content = tf.matmul(x, y, transpose_b=transpose)

  # Position term: put the length axis first and fold batch and heads
  # together -> [length or 1, batch_size * heads, length or depth].
  x_by_position = tf.reshape(tf.transpose(x, [2, 0, 1, 3]),
                             [length, heads * batch_size, -1])
  position = tf.matmul(x_by_position, z, transpose_b=transpose)
  # Unfold and restore [batch_size, heads, length or 1, length or depth].
  position = tf.transpose(
      tf.reshape(position, [length, batch_size, heads, -1]), [1, 2, 0, 3])
  return content + position
def dot_product_attention_relative(q,
                                   k,
                                   v,
                                   bias,
                                   max_relative_position,
                                   dropout_rate=0.0,
                                   image_shapes=None,
                                   save_weights_to=None,
                                   name=None,
                                   make_image_summary=True,
                                   cache=False,
                                   allow_memory=False,
                                   hard_attention_k=0,
                                   gumbel_noise_weight=0.0):
  """Calculate relative position-aware dot-product self-attention.

  The attention calculation is augmented with learned representations for the
  relative position between each element in q and each element in k and v.

  Args:
    q: a Tensor with shape [batch, heads, length, depth_k].
    k: a Tensor with shape [batch, heads, length, depth_k].
    v: a Tensor with shape [batch, heads, length, depth_v]. depth_v may differ
      from depth_k; the value relative-position embeddings use depth_v.
    bias: bias Tensor; it is cast to the dtype of the logits before being
      added.
    max_relative_position: an integer specifying the maximum distance between
        inputs that unique position embeddings should be learned for.
    dropout_rate: a floating point number.
    image_shapes: optional tuple of integer scalars.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    name: an optional string.
    make_image_summary: Whether to make an attention image summary.
    cache: whether use cache mode
    allow_memory: whether to assume that recurrent memory is in use. If True,
      the length dimension of k/v/bias may be longer than the queries, and it is
      assumed that the extra memory entries precede the non-memory entries.
    hard_attention_k: integer, if > 0 triggers hard attention (picking top-k)
    gumbel_noise_weight: if > 0, apply Gumbel noise with weight
      `gumbel_noise_weight` before picking top-k. This is a no op if
      hard_attention_k <= 0.

  Returns:
    A Tensor with shape [batch, heads, length_q, depth_v].

  Raises:
    ValueError: if max_relative_position is not > 0.
  """
  if not max_relative_position:
    raise ValueError("Max relative position (%s) should be > 0 when using "
                     "relative self attention." % (max_relative_position))
  with tf.variable_scope(
      name, default_name="dot_product_attention_relative",
      values=[q, k, v]) as scope:

    # This calculation only works for self attention.
    # q, k and v must therefore have the same shape (v may differ in depth),
    # unless cache or memory is enabled.
    if not cache and not allow_memory:
      q.get_shape().assert_is_compatible_with(k.get_shape())
      q.get_shape()[:-1].assert_is_compatible_with(v.get_shape()[:-1])

    # Use separate embeddings suitable for keys and values; each is sized to
    # the depth of the tensor it is combined with.
    depth_k = k.get_shape().as_list()[3]
    depth_v = v.get_shape().as_list()[3]
    if depth_v is None:
      # Static value depth unknown: fall back to the key depth, which is what
      # was used for both embeddings previously.
      depth_v = depth_k
    length_k = common_layers.shape_list(k)[2]
    length_q = common_layers.shape_list(q)[2] if allow_memory else length_k
    relations_keys = _generate_relative_positions_embeddings(
        length_q, length_k, depth_k, max_relative_position,
        "relative_positions_keys", cache=cache)
    relations_values = _generate_relative_positions_embeddings(
        length_q, length_k, depth_v, max_relative_position,
        "relative_positions_values", cache=cache)

    # Compute self attention considering the relative position embeddings.
    logits = _relative_attention_inner(q, k, relations_keys, True)
    if bias is not None:
      logits += common_layers.cast_like(bias, logits)
    weights = tf.nn.softmax(logits, name="attention_weights")
    if hard_attention_k > 0:
      weights = harden_attention_weights(weights, hard_attention_k,
                                         gumbel_noise_weight)
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = logits
    weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
    if (not tf.get_variable_scope().reuse and
        common_layers.should_generate_summaries() and
        make_image_summary):
      attention_image_summary(weights, image_shapes)
    return _relative_attention_inner(weights, v, relations_values, False)
def _relative_position_to_absolute_position_masked(x):
  """Helper to dot_product_self_attention_relative_v2.

  Rearrange an attention logits or weights Tensor.

  The dimensions of the input represent:
  [batch, heads, query_position, memory_position - query_position + length - 1]

  The dimensions of the output represent:
  [batch, heads, query_position, memory_position]

  Only works with masked_attention.  Undefined behavior for regions of the
  input where memory_position > query_position.

  Args:
    x: a Tensor with shape [batch, heads, length, length]

  Returns:
    a Tensor with shape [batch, heads, length, length]
  """
  batch, heads, length, _ = common_layers.shape_list(x)
  # Prepend one zero column, then view each [length, length + 1] block as
  # [length + 1, length]. Input entry (q, m - q + length - 1) then lands at
  # row q + 1, column m of the reshaped block.
  padded = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]])
  skewed = tf.reshape(padded, [batch, heads, 1 + length, length])
  # Drop the leading row so that row q corresponds to query q.
  return tf.slice(skewed, [0, 0, 1, 0], [-1, -1, -1, -1])
def _absolute_position_to_relative_position_masked(x):
  """Helper to dot_product_self_attention_relative_v2.

  Rearrange an attention logits or weights Tensor.

  The dimensions of the input represent:
  [batch, heads, query_position, memory_position]

  The dimensions of the output represent:
  [batch, heads, query_position, memory_position - query_position + length - 1]

  Only works with masked_attention.  Undefined behavior for regions of the
  input where memory_position > query_position.

  Args:
    x: a Tensor with shape [batch, heads, length, length]

  Returns:
    a Tensor with shape [batch, heads, length, length]
  """
  batch, heads, length, _ = common_layers.shape_list(x)
  # Prepend one zero row, then view each [length + 1, length] block as
  # [length, length + 1]. Input entry (q, m) then lands at row q, column
  # m - q + length of the reshaped block.
  padded = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])
  unskewed = tf.reshape(padded, [batch, heads, length, length + 1])
  # Drop the leading column to get the relative index m - q + length - 1.
  return tf.slice(unskewed, [0, 0, 0, 1], [batch, heads, length, length])
def get_relative_embeddings_left(max_relative_position, length, depth,
                                 num_heads, heads_share_relative_embedding,
                                 name):
  """Instantiate or retrieve relative embeddings, sliced according to length.

  Use for masked case where the relative attention is only looking left.
  The variable holds max_relative_position rows. When `length` is larger,
  zero rows are prepended; otherwise only the last `length` rows are kept.
  Either way exactly `length` rows are returned, without using tf.cond.

  Args:
    max_relative_position: an Integer for the number of entries in the relative
      embedding, which corresponds to the max relative distance that is
      considered.
    length: an Integer, specifies the length of the input sequence for which
      this relative embedding is retrieved for.
    depth: an Integer, specifies the depth for relative embeddings.
    num_heads: an Integer, specifies the number of heads.
    heads_share_relative_embedding: a Boolean specifying if the relative
      embedding is shared across heads.
    name: a string giving the name of the embedding variables.

  Returns:
    a Tensor with shape [length, depth] if heads_share_relative_embedding,
    otherwise [num_heads, length, depth].
  """
  if heads_share_relative_embedding:
    embedding_shape = (max_relative_position, depth)
  else:
    embedding_shape = (num_heads, max_relative_position, depth)
  relative_embeddings = tf.get_variable(
      name=name, shape=embedding_shape,
      initializer=tf.random_normal_initializer(stddev=depth**-0.5))
  # Pad first before slice to avoid using tf.cond.
  pad_length = tf.maximum(length - max_relative_position, 0)
  start_slice_position = tf.maximum(max_relative_position - length, 0)
  if heads_share_relative_embedding:
    paddings = [[pad_length, 0], [0, 0]]
    begin = [start_slice_position, 0]
    size = [length, -1]
  else:
    paddings = [[0, 0], [pad_length, 0], [0, 0]]
    begin = [0, start_slice_position, 0]
    size = [-1, length, -1]
  return tf.slice(tf.pad(relative_embeddings, paddings), begin, size)
def dot_product_self_attention_relative_v2(q,
                                           k,
                                           v,
                                           bias,
                                           max_relative_position=None,
                                           dropout_rate=0.0,
                                           image_shapes=None,
                                           save_weights_to=None,
                                           name=None,
                                           make_image_summary=True,
                                           dropout_broadcast_dims=None,
                                           heads_share_relative_embedding=False,
                                           add_relative_to_values=False):
  """Calculate relative position-aware dot-product self-attention.

  Only works for masked self-attention (no looking forward).

  The attention calculation is augmented with learned representations for the
  relative position between each element in q and each element in k and v.

  Args:
    q: a Tensor with shape [batch, heads, length, depth].
    k: a Tensor with shape [batch, heads, length, depth].
    v: a Tensor with shape [batch, heads, length, depth_v]; only its depth
      may differ from q.
    bias: bias Tensor.
    max_relative_position: an integer indicating the maximum relative distance
      to look back - changing this invalidates checkpoints
    dropout_rate: a floating point number.
    image_shapes: optional tuple of integer scalars.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    name: an optional string.
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions;
      broadcasting saves memory.
    heads_share_relative_embedding: a boolean indicating whether to share
      relative embeddings between attention heads.
    add_relative_to_values: a boolean for whether to add relative component to
      values.

  Returns:
    A Tensor.

  Raises:
    ValueError: if max_relative_position is not > 0.
  """
  if not max_relative_position:
    raise ValueError("Max relative position (%s) should be > 0 when using "
                     "relative self attention." % (max_relative_position))
  with tf.variable_scope(
      name,
      default_name="dot_product_self_attention_relative_v2",
      values=[q, k, v]) as scope:

    # Self-attention only: q and k must match exactly; v may differ in depth.
    q.get_shape().assert_is_compatible_with(k.get_shape())
    q.get_shape()[:-1].assert_is_compatible_with(v.get_shape()[:-1])

    _, num_heads, length, depth_k = common_layers.shape_list(k)

    # Content logits: [batch, num_heads, query_length, memory_length].
    logits = tf.matmul(q, k, transpose_b=True)
    key_embeddings = get_relative_embeddings_left(
        max_relative_position, length, depth_k, num_heads,
        heads_share_relative_embedding, "key_relative_embeddings")
    # Position logits come out in relative indexing; convert to absolute.
    logits += _relative_position_to_absolute_position_masked(
        matmul_with_relative_keys(q, key_embeddings,
                                  heads_share_relative_embedding))
    if bias is not None:
      logits += bias

    weights = tf.nn.softmax(logits, name="attention_weights")
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = logits
    # Dropping out the attention links for each of the heads.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      attention_image_summary(weights, image_shapes)

    result = tf.matmul(weights, v)
    if not add_relative_to_values:
      return result
    # Re-index weights relative to the query:
    # [batch, num_heads, query_length, memory_length].
    relative_weights = _absolute_position_to_relative_position_masked(weights)
    depth_v = common_layers.shape_list(v)[3]
    value_embeddings = get_relative_embeddings_left(
        max_relative_position, length, depth_v, num_heads,
        heads_share_relative_embedding, "value_relative_embeddings")
    result += matmul_with_relative_values(
        relative_weights, value_embeddings, heads_share_relative_embedding)
    return result
def _absolute_position_to_relative_position_unmasked(x):
  """Helper function for dot_product_unmasked_self_attention_relative_v2.

  Rearrange an attention logits or weights Tensor.

  The dimensions of the input represent:
  [batch, heads, query_position, memory_position]

  The dimensions of the output represent:
  [batch, heads, query_position, memory_position - query_position + length - 1]

  Only works with unmasked_attention.

  Args:
    x: a Tensor with shape [batch, heads, length, length]

  Returns:
    a Tensor with shape [batch, heads, length, 2*length-1]
  """
  batch, heads, length, _ = common_layers.shape_list(x)
  # Append length - 1 zero columns: each row now has 2 * length - 1 entries.
  widened = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, length - 1]])
  flat = tf.reshape(widened, [batch, heads, length * (2 * length - 1)])
  # Prepend `length` zeros and re-chunk into rows of 2 * length. Entry (q, m)
  # then lands at column m - q + length of row q.
  flat = tf.pad(flat, [[0, 0], [0, 0], [length, 0]])
  skewed = tf.reshape(flat, [batch, heads, length, 2 * length])
  # Dropping column 0 yields the relative index m - q + length - 1.
  return tf.slice(skewed, [0, 0, 0, 1], [batch, heads, length, 2 * length - 1])
def get_relative_embeddings_left_right(max_relative_position, length, depth,
                                       num_heads,
                                       heads_share_relative_embedding,
                                       name):
  """Instantiate or retrieve relative embeddings, sliced according to length.

  Use for unmasked case where the relative attention looks both left and right.
  The variable holds 2 * max_relative_position - 1 rows. When `length` exceeds
  max_relative_position it is zero-padded equally on both sides; otherwise a
  central window is sliced. Either way exactly 2 * length - 1 rows are
  returned, without using tf.cond.

  Args:
    max_relative_position: an Integer for the number of entries in the relative
      embedding, which corresponds to the max relative distance that is
      considered.
    length: an Integer, specifies the length of the input sequence for which
      this relative embedding is retrieved for.
    depth: an Integer, specifies the depth for relative embeddings.
    num_heads: an Integer, specifies the number of heads.
    heads_share_relative_embedding: a Boolean specifying if the relative
      embedding is shared across heads.
    name: a string giving the name of the embedding variables.

  Returns:
    a Tensor with shape [2 * length - 1, depth] if
    heads_share_relative_embedding, otherwise
    [num_heads, 2 * length - 1, depth].
  """
  window = 2 * max_relative_position - 1
  if heads_share_relative_embedding:
    embedding_shape = (window, depth)
  else:
    embedding_shape = (num_heads, window, depth)
  relative_embeddings = tf.get_variable(
      name=name, shape=embedding_shape,
      initializer=tf.random_normal_initializer(stddev=depth**-0.5))
  # Pad first before slice to avoid using tf.cond.
  pad_length = tf.maximum(length - max_relative_position, 0)
  slice_start_position = tf.maximum(max_relative_position - length, 0)
  if heads_share_relative_embedding:
    paddings = [[pad_length, pad_length], [0, 0]]
    begin = [slice_start_position, 0]
    size = [2 * length - 1, -1]
  else:
    paddings = [[0, 0], [pad_length, pad_length], [0, 0]]
    begin = [0, slice_start_position, 0]
    size = [-1, 2 * length - 1, -1]
  return tf.slice(tf.pad(relative_embeddings, paddings), begin, size)
def dot_product_unmasked_self_attention_relative_v2(
    q, k, v, bias, max_relative_position=None, dropout_rate=0.0,
    image_shapes=None, save_weights_to=None, name=None, make_image_summary=True,
    dropout_broadcast_dims=None, heads_share_relative_embedding=False,
    add_relative_to_values=False):
  """Calculate relative position-aware dot-product self-attention.

  The attention calculation is augmented with learned representations for the
  relative position between each element in q and each element in k and v.
  Relative embeddings cover both directions (left and right of each query);
  no masking is applied here apart from whatever `bias` adds to the logits.

  Args:
    q: a Tensor with shape [batch, heads, length, depth].
    k: a Tensor with shape [batch, heads, length, depth].
    v: a Tensor with shape [batch, heads, length, depth].
    bias: bias Tensor added to the logits, or None.
    max_relative_position: an integer the max relative embedding considered.
      Changing this invalidates checkpoints.
    dropout_rate: a floating point number.
    image_shapes: optional tuple of integer scalars.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name). The
      logits are also stored under that key with a "/logits" suffix.
    name: an optional string.
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    heads_share_relative_embedding: a boolean indicating whether to share
      relative embeddings between attention heads.
    add_relative_to_values: a boolean for whether to add relative component to
      values.

  Returns:
    A Tensor, the attention output (tf.matmul of the attention weights with
    v, i.e. [batch, heads, length, depth] for the documented input shapes).

  Raises:
    ValueError: if max_relative_position is None or 0. Negative values are not
      rejected by this check.
  """
  if not max_relative_position:
    raise ValueError("Max relative position (%s) should be > 0 when using "
                     "relative self attention." % (max_relative_position))
  with tf.variable_scope(
      name,
      default_name="dot_product_unmasked_self_attention_relative_v2",
      values=[q, k, v]) as scope:

    # This calculation only works for self attention.
    # q, k and v must therefore have the same shape.
    q.get_shape().assert_is_compatible_with(k.get_shape())
    q.get_shape().assert_is_compatible_with(v.get_shape())

    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(q, k, transpose_b=True)

    length = common_layers.shape_list(q)[2]
    k_shape = common_layers.shape_list(k)
    num_heads = k_shape[1]
    depth_k = k_shape[-1]

    # Relative key embeddings covering distances -(length-1)..(length-1).
    key_relative_embeddings = get_relative_embeddings_left_right(
        max_relative_position, length, depth_k, num_heads,
        heads_share_relative_embedding,
        "key_relative_embeddings")
    # Logits indexed by relative position, then re-indexed so that they line
    # up with the absolute [query, key] logits above.
    unmasked_rel_logits = matmul_with_relative_keys(
        q, key_relative_embeddings, heads_share_relative_embedding)
    unmasked_rel_logits = _relative_position_to_absolute_position_unmasked(
        unmasked_rel_logits)
    logits += unmasked_rel_logits

    if bias is not None:
      logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = logits
    # dropping out the attention links for each of the heads
    # (the helper is passed 1.0 - dropout_rate, presumably a keep probability)
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      attention_image_summary(weights, image_shapes)
    ret = tf.matmul(weights, v)
    if add_relative_to_values:
      # Adds the contribution of the weighted relative embeddings to the values.
      # [batch, num_heads, query_length, 2*memory_length-1]
      relative_weights = _absolute_position_to_relative_position_unmasked(
          weights)
      depth_v = common_layers.shape_list(v)[3]
      value_relative_embeddings = get_relative_embeddings_left_right(
          max_relative_position, length, depth_v, num_heads,
          heads_share_relative_embedding, "value_relative_embeddings")
      ret += matmul_with_relative_values(
          relative_weights, value_relative_embeddings,
          heads_share_relative_embedding)
    return ret
def _matmul_with_relative_keys_2d(x, y, heads_share_relative_embedding):
  """Helper for dot_product_unmasked_self_attention_relative_2d.

  Contracts the last axis of the 5-d query tensor `x` with the last axis of
  the relative embeddings `y`. `y` has a leading heads axis unless the
  embeddings are shared across heads.
  """
  equation = ("bhxyd,md->bhxym" if heads_share_relative_embedding
              else "bhxyd,hmd->bhxym")
  return tf.einsum(equation, x, y)
def dot_product_unmasked_self_attention_relative_2d(
    q, k, v, bias, max_relative_position=None, dropout_rate=0.0,
    image_shapes=None, name=None, make_image_summary=True,
    dropout_broadcast_dims=None, heads_share_relative_embedding=False,
    add_relative_to_values=False):
  """Calculate relative position unmasked dot-product self-attention 2d.

  The attention calculation is augmented with learned representations for the
  relative position between each element in q and each element in k and v in
  the height and width dimensions. For query index (i, j) and key index
  (l, m), the logit is q_{ij} k_{lm}^T + q_{ij} rh_{l-i}^T + q_{ij} rw_{m-j}^T,
  where rh and rw are the sets of relative embeddings in the height and width
  spatial dimensions, respectively.

  Args:
    q: a Tensor with shape [batch, heads, height, width, depth].
    k: a Tensor with shape [batch, heads, height, width, depth].
    v: a Tensor with shape [batch, heads, height, width, depth].
    bias: bias Tensor added to the flattened [height*width, height*width]
      logits, or None.
    max_relative_position: an integer the max relative embedding considered.
      Changing this invalidates checkpoints.
    dropout_rate: a floating point number.
    image_shapes: optional tuple of integer scalars.
    name: an optional string.
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    heads_share_relative_embedding: a boolean indicating whether to share
      relative embeddings between attention heads.
    add_relative_to_values: a boolean for adding relative embeddings to values.
      Only False is supported.

  Returns:
    A 3-tuple:
      [batch, heads, height, width, depth_v] tensor, the output of attention.
      height_key_relative_embeddings: a 3d or 2d tensor, depending on head
        sharing settings, which are the relative embeddings for height.
      width_key_relative_embeddings: a 3d or 2d tensor, depending on head
        sharing settings, which are the relative embeddings for width.

  Raises:
    ValueError: if max_relative_position is None or 0, or if
      add_relative_to_values is True (not implemented).
  """
  if not max_relative_position:
    raise ValueError("Max relative position (%s) should be > 0 when using "
                     "relative self attention." % (max_relative_position))

  if add_relative_to_values:
    raise ValueError("Adding relative embeddings to values is not implemented")

  # NOTE(review): default_name differs from the function name. It determines
  # the variable scope, and therefore the variable names, so it is presumably
  # kept for checkpoint compatibility.
  with tf.variable_scope(
      name,
      default_name="dot_product_self_attention_relative_v2",
      values=[q, k, v]):

    # This calculation only works for self attention.
    # q, k and v must therefore have the same shape (v may differ in depth).
    q.get_shape().assert_is_compatible_with(k.get_shape())
    q.get_shape()[:-1].assert_is_compatible_with(v.get_shape()[:-1])

    (height, width) = (common_layers.shape_list(q)[2],
                       common_layers.shape_list(q)[3])
    k_shape = common_layers.shape_list(k)
    num_heads = k_shape[1]
    depth_k = k_shape[-1]
    depth_v = common_layers.shape_list(v)[-1]
    # flatten height width: [batch, heads, height, width, d] ->
    # [batch, heads, height*width, d]
    flatten_hw = lambda x, d: tf.reshape(x, [-1, num_heads, height*width, d])
    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(flatten_hw(q, depth_k), flatten_hw(k, depth_k),
                       transpose_b=True)

    def _compute_2d_relative_logits(
        query, key_relative_embeddings, height, width,
        heads_share_relative_embedding, transpose_mask):
      """Computes relative logits along `width`, broadcast over `height`.

      Returns a [batch, heads, height*width, height*width] tensor laid out
      according to `transpose_mask`.
      """
      # [batch, heads, height, width, 2*width-1]
      unmasked_rel_logits = _matmul_with_relative_keys_2d(
          query, key_relative_embeddings, heads_share_relative_embedding)
      # collapse height and heads
      unmasked_rel_logits = tf.reshape(unmasked_rel_logits,
                                       [-1, num_heads*height, width,
                                        2*width-1])
      # relative -> absolute key index: [..., width, width]
      unmasked_rel_logits = (
          _relative_position_to_absolute_position_unmasked(
              unmasked_rel_logits))
      # shape it back for tiling
      unmasked_rel_logits = tf.reshape(
          unmasked_rel_logits, [-1, num_heads, height, width, width])
      # tiling it height times:
      # [batch, heads, height, height, width, width]
      unmasked_rel_logits = tf.expand_dims(
          unmasked_rel_logits, axis=3)
      unmasked_rel_logits = tf.tile(unmasked_rel_logits,
                                    [1, 1, 1, height, 1, 1])
      # bringing it to the right shape for adding to the logits.
      unmasked_rel_logits = tf.transpose(unmasked_rel_logits, transpose_mask)
      unmasked_rel_logits = tf.reshape(unmasked_rel_logits,
                                       [-1, num_heads, height*width,
                                        height*width])
      return unmasked_rel_logits

    # Relative logits in width dimension first.
    width_key_relative_embeddings = get_relative_embeddings_left_right(
        max_relative_position, width, depth_k, num_heads,
        heads_share_relative_embedding,
        "width_key_relative_embeddings")
    # [batch, heads, height*width, height*width]
    width_unmasked_rel_logits = _compute_2d_relative_logits(
        q, width_key_relative_embeddings, height, width,
        heads_share_relative_embedding, [0, 1, 2, 4, 3, 5])
    logits += width_unmasked_rel_logits
    # Relative logits in height dimension next. For ease, we transpose
    # height and width and repeat the above steps, and transpose to eventually
    # put the logits in their right positions.
    # Result: [batch, heads, height*width, height*width]
    height_key_relative_embeddings = get_relative_embeddings_left_right(
        max_relative_position, height, depth_k, num_heads,
        heads_share_relative_embedding,
        "height_key_relative_embeddings")

    height_unmasked_rel_logits = _compute_2d_relative_logits(
        tf.transpose(q, [0, 1, 3, 2, 4]),
        height_key_relative_embeddings,
        width,
        height,
        heads_share_relative_embedding, [0, 1, 4, 2, 5, 3])
    logits += height_unmasked_rel_logits
    if bias is not None:
      logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    # dropping out the attention links for each of the heads
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      attention_image_summary(weights, image_shapes)
    ret = tf.matmul(weights, flatten_hw(v, depth_v))
    # reshape back the same spatial dimensions as q
    return (
        tf.reshape(ret, [-1, num_heads, height, width, depth_v]),
        height_key_relative_embeddings,
        width_key_relative_embeddings)
def _split_along_width(x_left_right_blocks):
  """Helper for local 2d attention: separates alternating width blocks.

  Args:
    x_left_right_blocks: a 6-d tensor of shape [batch, num_h_blocks,
      num_outer_w_blocks, block_h, block_w, depth]. num_outer_w_blocks must be
      even, because the blocks are grouped in (even, odd) pairs.

  Returns:
    x_left_blocks, x_right_blocks: two tensors of shape [batch, num_h_blocks,
      (num_outer_w_blocks - 1) // 2, block_h, block_w, depth]. Writing k for
      that count, the left tensor holds width positions 0, 2, ..., 2k-2 and
      the right tensor holds width positions 3, 5, ..., 2k+1.
  """
  (_, num_h_blocks, num_outer_w_blocks, block_h,
   block_w, depth) = common_layers.shape_list(x_left_right_blocks)
  num_kept = (num_outer_w_blocks - 1) // 2
  # Group consecutive width blocks into (even, odd) pairs on a new axis 3.
  paired = tf.reshape(x_left_right_blocks,
                      [-1, num_h_blocks, num_outer_w_blocks // 2, 2,
                       block_h, block_w, depth])
  evens, odds = tf.unstack(paired, num=2, axis=3)
  sizes = [-1, -1, num_kept, -1, -1, -1]
  x_left_blocks = tf.slice(evens, [0, 0, 0, 0, 0, 0], sizes)
  # Skip the first odd block (width position 1).
  x_right_blocks = tf.slice(odds, [0, 0, 1, 0, 0, 0], sizes)
  return x_left_blocks, x_right_blocks
def _get_left_right_blocks(x):
  """Helper function. Assumes that memory_flange is half of query sizes.

  Drops the outermost block rows of `x`, fuses the remaining block rows in
  pairs (doubling the block height), then splits along width. The first
  result gets width indices 0, 2, 4, ... and the second gets width indices
  3, 5, ...

  Args:
    x: a 6-d tensor of shape [batch, num_outer_h_blocks, num_outer_w_blocks,
      block_h, block_w, depth].

  Returns:
    x_left_blocks, x_right_blocks: Two 6-d tensors.
  """
  (_, num_outer_h, num_outer_w, flange_h,
   flange_w, depth) = common_layers.shape_list(x)
  num_blocks_h = (num_outer_h - 2) // 2
  # Drop the first and last rows of blocks.
  inner_rows = tf.slice(x, [0, 1, 0, 0, 0, 0],
                        [-1, num_outer_h - 2, -1, -1, -1, -1])
  # [b, nh, 2, w, fh, fw, d] -> [b, nh, w, 2, fh, fw, d] -> [b, nh, w, 2*fh,
  # fw, d]: each pair of consecutive block rows becomes one taller block.
  row_pairs = tf.reshape(inner_rows, [-1, num_blocks_h, 2, num_outer_w,
                                      flange_h, flange_w, depth])
  row_pairs = tf.transpose(row_pairs, [0, 1, 3, 2, 4, 5, 6])
  fused = tf.reshape(row_pairs, [-1, num_blocks_h, num_outer_w,
                                 2 * flange_h, flange_w, depth])
  return _split_along_width(fused)
def _extract_blocks(x, block_h, block_w):
  """Helper function for local 2d attention: tiles x into blocks.

  Args:
    x: a [batch, height, width, depth] tensor.
    block_h: an integer, the block height; must divide height.
    block_w: an integer, the block width; must divide width.

  Returns:
    a [batch, height // block_h, width // block_w, block_h, block_w, depth]
    tensor.
  """
  _, height, width, depth = common_layers.shape_list(x)
  assert height % block_h == 0
  assert width % block_w == 0
  num_rows, num_cols = height // block_h, width // block_w
  tiled = tf.reshape(x, [-1, num_rows, block_h, num_cols, block_w, depth])
  # Move the block-column axis ahead of the in-block row axis so that both
  # block-index axes come first.
  return tf.transpose(tiled, [0, 1, 3, 2, 4, 5])
def get_2d_local_memory(x, query_shape, memory_flange):
  """Stitches together the local 2d memory blocks.

  The input is zero-padded by memory_flange on each spatial side and cut into
  flange-sized blocks. The nine pieces around every query block (corners,
  edges and the center query block itself) are then reassembled into one
  memory block. The helpers used here (_get_left_right_blocks) assume
  memory_flange is half of query_shape.

  Args:
    x: a [batch, height, width, depth] tensor.
    query_shape: 2-d integer list of query shape.
    memory_flange: 2-d integer list of memory flanges.

  Returns:
    x: A [batch, num_h_blocks, num_w_blocks,
      query_shape[0]+2*memory_flange[0], query_shape[1]+2*memory_flange[1],
      depth] tensor.
  """
  (_, height, width, depth_x) = common_layers.shape_list(x)
  x_center_blocks = _extract_blocks(x, query_shape[0], query_shape[1])
  # add extra padding to x so that we can extract the memory region
  # around the center
  paddings = [[0, 0], [memory_flange[0], memory_flange[0]],
              [memory_flange[1], memory_flange[1]], [0, 0]]
  padded_x = tf.pad(x, paddings)
  padded_x.set_shape([None, height+2*memory_flange[0],
                      width+2*memory_flange[1], depth_x])
  x_outer_memory_blocks = _extract_blocks(padded_x,
                                          memory_flange[0], memory_flange[1])
  # We'll extract left and right memory blocks, top and bottom memory blocks,
  # and then the corner memory blocks.
  # The left and right blocks have shape
  # [batch, num_h_blocks, num_w_blocks, query_shape[0],
  # memory_flange[1], depth]
  x_left_blocks, x_right_blocks = _get_left_right_blocks(
      x_outer_memory_blocks)
  # Swaps both the block-grid axes and the in-block h/w axes.
  t_hw_block = lambda x: tf.transpose(x, [0, 2, 1, 4, 3, 5])
  # now to get top and bottom blocks, we should just transpose the outer
  # blocks, call the same function and transpose back to get shape
  # [batch, num_h_blocks, num_w_blocks, memory_flange[0],
  # query_shape[1], depth]
  x_top_center_blocks, x_bottom_center_blocks = (
      map(t_hw_block, _get_left_right_blocks(
          t_hw_block(x_outer_memory_blocks))))

  # now to get the corner blocks
  x_left_corner_blocks, x_right_corner_blocks = _split_along_width(
      x_outer_memory_blocks)
  # now to extract top and bottom for both k and v
  # we need to transpose because _split_along_width separates along
  # the width
  # each of these should have shape [batch, num_h_blocks,
  # num_w_blocks, memory_flange[0], memory_flange[1], depth]

  # Swaps only the block-grid axes.
  t_hw = lambda x: tf.transpose(x, [0, 2, 1, 3, 4, 5])
  x_top_left_corner_blocks, x_bottom_left_corner_blocks = (
      map(t_hw, _split_along_width(t_hw(x_left_corner_blocks))))
  x_top_right_corner_blocks, x_bottom_right_corner_blocks = (
      map(t_hw, _split_along_width(t_hw(x_right_corner_blocks))))

  # The memory is top_left     top_center    top_right
  #               left_center  middle        right_center
  #               bottom_left  bottom_center bottom_right
  # Assembling the above row by row
  # first [x_top_left, x_top, x_top_right]
  # to get [batch, num_h_blocks, num_w_blocks, memory_flange[0],
  # query_shape[1]+2*memory_flange[1], depth]
  # then [x_left, x_center, x_right]
  # then [x_bottom_left, x_bottom, x_bottom_right]
  x_top_memory = tf.concat(
      [x_top_left_corner_blocks,
       x_top_center_blocks,
       x_top_right_corner_blocks], axis=4)
  x_middle_memory = tf.concat(
      [x_left_blocks, x_center_blocks, x_right_blocks], axis=4)
  x_bottom_memory = tf.concat(
      [x_bottom_left_corner_blocks,
       x_bottom_center_blocks,
       x_bottom_right_corner_blocks], axis=4)

  # concat along height
  x = tf.concat([x_top_memory, x_middle_memory, x_bottom_memory], axis=3)
  return x
def get_2d_local_memory_v2(x, query_shape, memory_flange):
  """Gathers the memory block around every query block.

  Only works if memory flanges are half of query sizes. The padded input is
  cut into query-sized tiles, and the memory of each query block is the 2x2
  group of tiles that covers the block plus its flange on every side.

  Args:
    x: a [batch, height, width, depth] tensor.
    query_shape: 2-d integer list of query shape.
    memory_flange: 2-d integer list of memory flanges; expected to be
      query_shape // 2.

  Returns:
    x: A [batch, num_h_blocks, num_w_blocks,
      query_shape[0]+2*memory_flange[0], query_shape[1]+2*memory_flange[1],
      depth] tensor, with num_h_blocks = height // query_shape[0] and
      num_w_blocks = width // query_shape[1].
  """
  (_, height, width, depth_x) = common_layers.shape_list(x)
  # add extra padding to x so that we can extract the memory region
  # around the center
  paddings = [[0, 0], [memory_flange[0], memory_flange[0]],
              [memory_flange[1], memory_flange[1]], [0, 0]]
  padded_x = tf.pad(x, paddings)
  padded_x.set_shape([None, height + 2 * memory_flange[0],
                      width + 2 * memory_flange[1], depth_x])
  # [batch, num_h_blocks + 1, num_w_blocks + 1, query_h, query_w, depth]
  x_memory_blocks = _extract_blocks(padded_x,
                                    query_shape[0], query_shape[1])
  # Join every tile with its right neighbour along the in-block width axis.
  # Two slices replace splitting into single tiles and re-concatenating all
  # but the last / first, which built O(num_w_blocks) ops for the same values.
  x_memory_blocks = tf.concat(
      [x_memory_blocks[:, :, :-1], x_memory_blocks[:, :, 1:]], axis=4)
  # Same along height: join every tile row with the row below it.
  x = tf.concat([x_memory_blocks[:, :-1], x_memory_blocks[:, 1:]], axis=3)
  return x
def dot_product_unmasked_attention_local_2d_tpu(
    q, k, v, bias, max_relative_position=None, query_shape=(8, 8),
    dropout_rate=0.0, image_shapes=None, name=None, make_image_summary=False,
    dropout_broadcast_dims=None):
  """Calculate unmasked dot-product local self-attention 2d on tpu.

  Each query block of size query_shape attends to the surrounding memory
  block of size 2 * query_shape, gathered by get_2d_local_memory_v2 with a
  flange of query_shape // 2.

  Args:
    q: a Tensor with shape [batch, heads, height, width, depth].
    k: a Tensor with shape [batch, heads, height, width, depth].
    v: a Tensor with shape [batch, heads, height, width, depth].
    bias: bias Tensor added to the per-block logits, or None.
    max_relative_position: an integer the max relative embedding considered.
      Changing this invalidates checkpoints. Must be falsy; relative local 2d
      attention is not implemented.
    query_shape: a two tuple indicating query shape
    dropout_rate: a floating point number.
    image_shapes: optional tuple of integer scalars.
    name: an optional string.
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.

  Returns:
    [batch, heads, height, width, depth] tensor, the output of attention.

  Raises:
    ValueError: if max_relative_position is set.
  """
  if max_relative_position:
    raise ValueError("Relative local 2d attention not implemented")

  with tf.variable_scope(
      name,
      default_name="dot_product_unmasked_attention_local_2d_tpu",
      values=[q, k, v]):

    # This calculation only works for self attention.
    # q, k and v must therefore have the same shape.
    q.get_shape().assert_is_compatible_with(k.get_shape())
    q.get_shape().assert_is_compatible_with(v.get_shape())
    orig_q_shape = common_layers.shape_list(q)
    # Pad query, key, value to ensure multiple of corresponding lengths.
    memory_flange = [int(query_shape[0]//2), int(query_shape[1]//2)]
    q = pad_to_multiple_2d(q, query_shape)
    k = pad_to_multiple_2d(k, query_shape)
    v = pad_to_multiple_2d(v, query_shape)
    # Padded dimensions: height and width are now multiples of query_shape.
    _, num_heads, height, width, depth_k = common_layers.shape_list(k)
    depth_v = common_layers.shape_list(v)[-1]
    num_h_blocks = height//query_shape[0]
    num_w_blocks = width//query_shape[1]
    # Extract center queries, keys, and values. Batch and heads are merged
    # into the leading axis for the block helpers.
    q = tf.reshape(q, [-1, height, width, depth_k])
    queries = _extract_blocks(
        q, query_shape[0], query_shape[1])
    k = tf.reshape(k, [-1, height, width, depth_k])
    keys = get_2d_local_memory_v2(
        k, query_shape, memory_flange)
    v = tf.reshape(v, [-1, height, width, depth_v])
    values = get_2d_local_memory_v2(
        v, query_shape, memory_flange)
    memory_h = query_shape[0] + 2*memory_flange[0]
    memory_w = query_shape[1] + 2*memory_flange[1]
    queries = tf.reshape(queries, [-1, num_heads, num_h_blocks, num_w_blocks,
                                   query_shape[0]*query_shape[1], depth_k])
    keys = tf.reshape(keys, [-1, num_heads, num_h_blocks, num_w_blocks,
                             memory_h*memory_w, depth_k])
    values = tf.reshape(values, [-1, num_heads, num_h_blocks, num_w_blocks,
                                 memory_h*memory_w, depth_v])
    logits = tf.matmul(queries, keys, transpose_b=True)
    if bias is not None:
      logits += bias

    weights = tf.nn.softmax(logits, name="attention_weights")
    # Dropping out the attention links for each of the heads
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      attention_image_summary(weights, image_shapes)
    ret = tf.matmul(weights, values)
    # we need to get it back to shape [batch, heads, height, width, depth_v]
    ret = tf.reshape(ret, [-1, num_heads, num_h_blocks, num_w_blocks,
                           query_shape[0], query_shape[1], depth_v])
    ret = tf.transpose(ret, [0, 1, 2, 4, 3, 5, 6])
    ret = tf.reshape(ret, [-1, num_heads, num_h_blocks*query_shape[0],
                           num_w_blocks*query_shape[1], depth_v])
    # slice if padding was introduced
    ret = tf.slice(ret, [0, 0, 0, 0, 0], [-1, -1, orig_q_shape[2],
                                          orig_q_shape[3], -1])
    return ret
def dot_product_unmasked_attention_local_2d_tpu_simple(
    x, bias, total_key_depth, total_value_depth, num_heads,
    query_shape=(8, 8),
    dropout_rate=0.0, image_shapes=None, make_image_summary=False,
    dropout_broadcast_dims=None):
  """Calculate simple unmasked dot-product local self-attention 2d on tpu.

  The query, key, and value blocks are the same. We do not do a second linear
  transformation after computing the values.

  Args:
    x: a Tensor with shape [batch, height, width, depth].
    bias: bias Tensor added to the per-block logits, or None.
    total_key_depth: the dimensions of the keys; must be divisible by
      num_heads.
    total_value_depth: the dimensions of the values; must be divisible by
      num_heads.
    num_heads: number of heads
    query_shape: a two tuple indicating query shape
    dropout_rate: a floating point number.
    image_shapes: optional tuple of integer scalars.
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.

  Returns:
    ret: [batch, height, width, total_value_depth] tensor,
      the output of attention.
    q, k, v: the per-block query, key and value tensors after split_heads
      (presumably [batch * num_h_blocks * num_w_blocks, num_heads,
      query_shape[0] * query_shape[1], total_{key,value}_depth // num_heads];
      they are NOT in [batch, height, width, depth] layout).
  """
  orig_x_shape = common_layers.shape_list(x)
  # Pad x to a multiple of query_shape only when necessary, so the common
  # case skips both the pad and the final slice.
  is_padded = False
  if (orig_x_shape[1]%query_shape[0]) != 0 or (
      orig_x_shape[2]%query_shape[1]) != 0:
    x = pad_to_multiple_2d(x, query_shape)
    is_padded = True
  _, height, width, depth = common_layers.shape_list(x)
  # Heads are split from the projected keys/values, not from x itself, so
  # those are the depths that must divide evenly.
  assert total_key_depth % num_heads == 0
  assert total_value_depth % num_heads == 0
  num_h_blocks = height//query_shape[0]
  num_w_blocks = width//query_shape[1]
  # Extract center queries, keys, and values; every block becomes its own
  # sequence of length query_shape[0] * query_shape[1].
  x_blocks = _extract_blocks(x, query_shape[0], query_shape[1])
  x_blocks = tf.reshape(x_blocks, [-1, query_shape[0]*query_shape[1], depth])
  q, k, v = compute_qkv(x_blocks, None, total_key_depth, total_value_depth)
  hsplit = lambda x: split_heads(x, num_heads)
  q, k, v = map(hsplit, [q, k, v])
  logits = tf.matmul(q, k, transpose_b=True)
  if bias is not None:
    logits += bias
  weights = tf.nn.softmax(logits, name="attention_weights")
  # Dropping out the attention links for each of the heads
  weights = common_layers.dropout_with_broadcast_dims(
      weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
  if common_layers.should_generate_summaries() and make_image_summary:
    attention_image_summary(weights, image_shapes)
  output = tf.matmul(weights, v)
  output = combine_heads(output)
  # we need to get it back to shape [batch, height, width, total_value_depth]
  ret = tf.reshape(output, [-1, num_h_blocks, num_w_blocks,
                            query_shape[0], query_shape[1], total_value_depth])

  ret = tf.transpose(ret, [0, 1, 3, 2, 4, 5])
  ret = tf.reshape(ret, [-1, num_h_blocks*query_shape[0],
                         num_w_blocks*query_shape[1], total_value_depth])
  # slice if padding was introduced
  if is_padded:
    ret = tf.slice(ret, [0, 0, 0, 0], [-1, orig_x_shape[1],
                                       orig_x_shape[2], -1])
  return ret, q, k, v
def masked_within_block_local_attention_1d(q, k, v, block_length=64, name=None):
  """Attention to the source and a neighborhood to the left within a block.

  The sequence is divided into blocks of length block_length. Attention for a
  given query position can only see memory positions less than or equal to the
  query position in the corresponding block.

  Args:
    q: a Tensor with shape [batch, heads, length, depth_k]
    k: a Tensor with shape [batch, heads, length, depth_k]
    v: a Tensor with shape [batch, heads, length, depth_v]
    block_length: an integer (a Tensor with a constant value is converted to
      a Python int)
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, length, depth_v]
  """
  with tf.variable_scope(
      name, default_name="within_local_attention_1d", values=[q, k, v]):
    batch, heads, length, depth_k = common_layers.shape_list(q)
    depth_v = common_layers.shape_list(v)[-1]
    if isinstance(block_length, tf.Tensor):
      const = contrib.util().constant_value(block_length)
      if const is not None:
        block_length = int(const)

    # Pad query, key, value to ensure multiple of block length.
    original_length = length
    padding_size = tf.mod(-length, block_length)
    length += padding_size
    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
    q = tf.pad(q, padding)
    k = tf.pad(k, padding)
    v = tf.pad(v, padding)

    # Compute attention within every block independently.
    num_blocks = tf.div(length, block_length)
    q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
    # [batch, heads, num_blocks, block_length, block_length]
    attention = tf.matmul(q, k, transpose_b=True)
    # Mask future positions inside each block.
    attention += tf.reshape(attention_bias_lower_triangle(block_length),
                            [1, 1, 1, block_length, block_length])
    attention = tf.nn.softmax(attention)
    # [batch, heads, num_blocks, block_length, depth_v]
    output = tf.matmul(attention, v)
    output = tf.reshape(output, [batch, heads, -1, depth_v])

    # Remove the padding if introduced.
    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
    # `length` includes the padding at this point; the output was just cut
    # back to `original_length`, so that is the length to declare.
    output.set_shape([None if isinstance(dim, tf.Tensor) else dim for dim in
                      (batch, heads, original_length, depth_v)])
    return output
def _relative_position_to_absolute_position_unmasked(x):
  """Converts tensor from relative to absolute indexing for local attention.

  Uses the pad-reshape-slice trick. Each row gets one column of zero padding,
  the result is flattened and padded with length - 1 more zeros, and then it
  is reshaped to [length + 1, 2 * length - 1]. This shifts row i left by i, so
  that the absolute [query, key] entries line up in the trailing columns.

  Args:
    x: a Tensor of shape [batch (or batch*num_blocks), heads,
      length, 2 * length - 1]

  Returns:
    A Tensor of shape [batch (or batch*num_blocks), heads, length, length]
    with the same dtype as x.
  """
  x_shape = common_layers.shape_list(x)
  batch = x_shape[0]
  heads = x_shape[1]
  length = x_shape[2]
  # Concat columns of pad to shift from relative to absolute indexing.
  # The pads are created in x's dtype. The float32 default would make
  # tf.concat fail for bfloat16/float16 inputs.
  col_pad = tf.zeros((batch, heads, length, 1), dtype=x.dtype)
  x = tf.concat([x, col_pad], axis=3)

  # Concat extra elements so to add up to shape (len+1, 2*len-1):
  # 2*len*len + (len-1) == (len+1) * (2*len-1).
  flat_x = tf.reshape(x, [batch, heads, length * 2 * length])
  flat_pad = tf.zeros((batch, heads, length - 1), dtype=x.dtype)
  flat_x_padded = tf.concat([flat_x, flat_pad], axis=2)

  # Reshape and slice out the padded elements.
  final_x = tf.reshape(flat_x_padded,
                       [batch, heads, length + 1, 2 * length - 1])
  final_x = final_x[:, :, :, length - 1:]
  final_x = final_x[:, :, :length, :]
  return final_x
def masked_local_attention_1d(q,
                              k,
                              v,
                              block_length=128,
                              make_image_summary=False,
                              dropout_rate=0.,
                              name=None):
  """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_length. Attention for a
  given query position can only see memory positions less than or equal to the
  query position, in the corresponding block and the previous block.

  The first block attends only within itself. Every later block attends to the
  concatenation [previous block, current block] built by _make_local_block.

  Args:
    q: a Tensor with shape [batch, heads, length, depth_k]
    k: a Tensor with shape [batch, heads, length, depth_k]
    v: a Tensor with shape [batch, heads, length, depth_v]
    block_length: an integer (a Tensor with a constant value is converted to
      a Python int)
    make_image_summary: a boolean, whether to make an attention image summary
      (only for the first block).
    dropout_rate: Dropout rate for attention dropout
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, length, depth_v]
  """
  with tf.variable_scope(
      name, default_name="local_attention_1d", values=[q, k, v]):
    batch, heads, length, depth_k = common_layers.shape_list(q)
    depth_v = common_layers.shape_list(v)[-1]
    if isinstance(block_length, tf.Tensor):
      const = contrib.util().constant_value(block_length)
      if const is not None:
        block_length = int(const)
    # If (length < 2 * block_length), then we use only one block.
    if isinstance(length, int) and isinstance(block_length, int):
      block_length = length if length < block_length * 2 else block_length
    else:
      block_length = tf.where(
          tf.less(length, block_length * 2), length, block_length)

    # Pad query, key, value to ensure multiple of block length.
    original_length = length
    padding_size = tf.mod(-length, block_length)
    length += padding_size
    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
    q = tf.pad(q, padding)
    k = tf.pad(k, padding)
    v = tf.pad(v, padding)

    if isinstance(length, int) and isinstance(block_length, int):
      num_blocks = length // block_length
    else:
      num_blocks = tf.div(length, block_length)

    # Compute attention for the first query block (causal, within itself).
    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])

    first_output = dot_product_attention(
        first_q,
        first_k,
        first_v,
        attention_bias_lower_triangle(block_length),
        dropout_rate=dropout_rate,
        make_image_summary=make_image_summary,
        name="first_block")

    # Compute attention for all subsequent query blocks.
    q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])

    # [batch, heads, num_blocks - 1, 2 * block_length, depth]: previous block
    # followed by the current block, for blocks 1..num_blocks-1.
    local_k = _make_local_block(k, depth_k, batch, heads, num_blocks,
                                block_length)
    local_v = _make_local_block(v, depth_v, batch, heads, num_blocks,
                                block_length)
    # Queries of blocks 1..num_blocks-1 (block 0 was handled above).
    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
    tail_q = tf.reshape(tail_q,
                        [batch, heads, num_blocks - 1, block_length, depth_k])
    local_length = common_layers.shape_list(local_k)[3]

    # make sure source_pos <= target_pos
    # Query i of a block sits at local position block_length + i, so keys up
    # to block_length positions to its right in the band stay visible. This
    # presumably yields a band of ones with key - query <= block_length.
    good_part = common_layers.ones_matrix_band_part(
        block_length,
        local_length,
        -1,
        block_length,
        out_shape=[1, 1, 1, block_length, local_length])
    bias = (1.0 - good_part) * -1e9
    # TODO(noam): figure out how to show a summary for the remaining blocks.
    # The naive way currently causes errors due to empty tensors.
    # output: [batch, heads, num_blocks-1, block_length, depth_v]
    tail_output = dot_product_attention(
        tail_q,
        local_k,
        local_v,
        bias,
        dropout_rate=dropout_rate,
        make_image_summary=False,
        name="tail_block")
    tail_output = tf.reshape(
        tail_output, [batch, heads, (num_blocks - 1) * block_length, depth_v])
    output = tf.concat([first_output, tail_output], axis=2)

    # Remove the padding if introduced.
    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
    output = tf.reshape(output, [batch, heads, original_length, depth_v])
    return output
def _make_local_block(x, depth, batch, heads, num_blocks, block_length):
  """Pairs every block after the first with its predecessor (1d attention).

  Args:
    x: a 5-d tensor laid out as [batch, heads, num_blocks, block_length,
      depth] (keys or values).
    depth, batch, heads, num_blocks, block_length: the corresponding sizes.

  Returns:
    A [batch, heads, num_blocks - 1, 2 * block_length, depth] tensor whose
    entry i is block i followed by block i + 1 along the length axis.
  """
  previous = tf.slice(x, [0, 0, 0, 0, 0],
                      [-1, -1, num_blocks - 1, -1, -1])
  current = tf.slice(x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
  window = tf.concat([previous, current], axis=3)
  return tf.reshape(
      window, [batch, heads, num_blocks - 1, block_length * 2, depth])
def masked_relative_local_attention_1d(q,
                                       k,
                                       v,
                                       block_length=128,
                                       make_image_summary=False,
                                       dropout_rate=0.,
                                       heads_share_relative_embedding=False,
                                       add_relative_to_values=False,
                                       name=None):
  """Masked local 1d attention with relative positions.

  The sequence is zero-padded to a multiple of block_length and divided into
  blocks of that length. Attention for a given query position can only see
  memory positions less than or equal to the query position, in the
  corresponding block and the previous block. Relative-position logits are
  added to the content logits, and optionally (add_relative_to_values) a
  relative-position contribution is added to the output values.

  Args:
    q: a Tensor with shape [batch, heads, length, depth_k]
    k: a Tensor with shape [batch, heads, length, depth_k]
    v: a Tensor with shape [batch, heads, length, depth_v]
    block_length: an integer. If length < 2 * block_length, a single block of
      size length is used instead.
    make_image_summary: a boolean, whether to make an attention image summary
      (only the first block is summarized).
    dropout_rate: Dropout rate for attention dropout
    heads_share_relative_embedding: a boolean for sharing relative embeddings.
    add_relative_to_values: a boolean for whether to add relative component to
      values.
    name: a string naming the variable scope. Required (see Raises), because
      the scope uses tf.AUTO_REUSE to share the relative embeddings between
      the first block and the tail blocks.

  Returns:
    a Tensor of shape [batch, heads, length, depth_v]

  Raises:
    ValueError: when the name for the variable scope is not passed.
  """
  if not name:
    raise ValueError("Name must be assigned since reuse for variable scope is "
                     "set to tf.AUTO_REUSE, in order to reuse relative "
                     "embeddings of keys and values.")

  # Reuse flag is set to auto_reuse to reuse relative embeddings of keys and
  # values across blocks (first and tail blocks).
  # NOTE(review): default_name is never used, because name is required above.
  with tf.variable_scope(
      name, default_name="masked_relative_local_attention_1d",
      values=[q, k, v], reuse=tf.AUTO_REUSE):

    default_block_length = block_length
    batch = common_layers.shape_list(q)[0]
    heads = common_layers.shape_list(q)[1]
    length = common_layers.shape_list(q)[2]
    # If (length < 2 * block_length), then we use only one block.
    if isinstance(length, int) and isinstance(block_length, int):
      block_length = length if length < block_length * 2 else block_length
    else:
      block_length = tf.where(
          tf.less(length, block_length * 2), length, block_length)
    depth_k = common_layers.shape_list(k)[3]
    depth_v = common_layers.shape_list(v)[3]
    original_length = length
    # Zero-pad the length axis up to a multiple of block_length; the padding is
    # sliced off again at the end.
    padding_size = tf.mod(-length, block_length)
    length += padding_size
    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
    q = tf.pad(q, padding)
    k = tf.pad(k, padding)
    v = tf.pad(v, padding)

    num_blocks = length // block_length
    # compute attention for the first query block.
    # The first block has no previous block, so it only attends within itself.
    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
    # Relative embeddings will be used later as well.
    # TODO(avaswani,annahuang): check why 2*bl was breaking for music
    # Needs to be known at static shape inference time, hence cannot be
    # 2 * block_length.
    rel_embed_length = 4 * default_block_length
    # We only multiply with the needed embeddings as we slice them out.
    first_rel_embeddings = get_relative_embeddings_left(
        rel_embed_length, block_length, depth_k, heads,
        heads_share_relative_embedding, "relative_embeddings")
    first_rel_logits = matmul_with_relative_keys(
        first_q, first_rel_embeddings, heads_share_relative_embedding)
    first_logits = tf.matmul(first_q, first_k, transpose_b=True)
    first_logits += (
        _relative_position_to_absolute_position_masked(first_rel_logits))
    # adding a mask
    first_logits += (
        common_layers.cast_like(attention_bias_lower_triangle(block_length),
                                first_logits))
    first_att = tf.nn.softmax(first_logits,
                              name="first_attention_weights")
    # dropping out the attention links for each of the heads
    first_att = common_layers.dropout_with_broadcast_dims(
        first_att, 1.0 - dropout_rate,
        broadcast_dims=None)
    # only call image summary for the first block
    if common_layers.should_generate_summaries() and make_image_summary:
      attention_image_summary(first_att, None)
    first_output = tf.matmul(first_att, first_v)

    # compute attention for all subsequent query blocks.
    q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
    # Keys/values for tail block i: block i-1 followed by block i.
    local_k = _make_local_block(k, depth_k, batch, heads, num_blocks,
                                block_length)
    local_v = _make_local_block(v, depth_v, batch, heads, num_blocks,
                                block_length)
    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
    tail_q = tf.reshape(tail_q,
                        [batch, heads, num_blocks - 1, block_length, depth_k])
    local_length = common_layers.shape_list(local_k)[3]

    # collapsing num blocks and batch size so that we can reuse
    # functions
    def _reshape_for_relative(x):
      x_shape = common_layers.shape_list(x)
      # [batch, num_blocks, heads, length, depth]
      x = tf.transpose(x, [0, 2, 1, 3, 4])
      x = tf.reshape(x, [batch*x_shape[2], heads, x_shape[3],
                         x_shape[4]])
      return x
    rel_tail_q = _reshape_for_relative(tail_q)
    rel_k = _reshape_for_relative(local_k)
    rel_v = _reshape_for_relative(local_v)
    rel_embeddings = get_relative_embeddings_left(
        rel_embed_length, 2 * block_length, depth_k, heads,
        heads_share_relative_embedding, "relative_embeddings")
    rel_logits = matmul_with_relative_keys(
        rel_tail_q, rel_embeddings, heads_share_relative_embedding)
    # Computing relative logits separately for the masked and unmasked parts
    # because the reshaping logic is different for both
    masked_rel_logits = tf.slice(rel_logits, [0, 0, 0, block_length],
                                 [-1, -1, -1, -1])
    masked_rel_logits = _relative_position_to_absolute_position_masked(
        masked_rel_logits)
    unmasked_rel_logits = tf.slice(rel_logits, [0, 0, 0, 0],
                                   [-1, -1, -1, 2*block_length-1])
    unmasked_rel_logits = _relative_position_to_absolute_position_unmasked(
        unmasked_rel_logits)
    all_rel_logits = tf.concat([unmasked_rel_logits, masked_rel_logits],
                               axis=3)
    all_logits = (
        tf.matmul(rel_tail_q, rel_k, transpose_b=True) + all_rel_logits)
    # make sure source_pos <= target_pos
    # Query i of the current block sits at local key position
    # i + block_length (the previous block fills the first block_length keys),
    # hence an upper band of block_length. This assumes ones_matrix_band_part
    # follows tf.linalg.band_part's (num_lower, num_upper) convention.
    good_part = common_layers.ones_matrix_band_part(block_length,
                                                    local_length,
                                                    -1, block_length)
    mask = (1.0 - good_part) * -1e9
    mask = common_layers.cast_like(mask, all_logits)
    all_logits += tf.reshape(mask, [1, 1, block_length, local_length])
    weights = tf.nn.softmax(all_logits, name="attention_weights")
    # [batch (* num_blocks), heads, query_length (=block_length),
    # key_length (=2*block_length)]
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate,
        broadcast_dims=None)

    output = tf.matmul(weights, rel_v)
    if add_relative_to_values:
      # Adds the contribution of the weighted relative embeddings to the values.
      # The first half of the key axis is the previous (unmasked) block and
      # the second half the current (masked) block.
      weights_for_unmasked, weights_for_masked = (
          tf.split(weights, 2, axis=3))
      rel_weights_unmasked = _absolute_position_to_relative_position_unmasked(
          weights_for_unmasked)
      rel_weights_masked = _absolute_position_to_relative_position_masked(
          weights_for_masked)

      value_rel_embeddings_unmasked = get_relative_embeddings_left(
          rel_embed_length, 2 * block_length, depth_v,
          heads, heads_share_relative_embedding,
          "value_relative_embeddings")
      # Keep all but the last embedding (2 * block_length - 1 rows), mirroring
      # the 2*block_length-1 unmasked logits sliced above; presumably this
      # matches the width produced by
      # _absolute_position_to_relative_position_unmasked -- TODO confirm.
      if heads_share_relative_embedding:
        value_rel_embeddings_unmasked = value_rel_embeddings_unmasked[:-1, :]
      else:
        value_rel_embeddings_unmasked = value_rel_embeddings_unmasked[:, :-1, :]
      value_rel_embeddings_masked = get_relative_embeddings_left(
          rel_embed_length, block_length, depth_v,
          heads, heads_share_relative_embedding,
          "value_relative_embeddings")

      # [batch (*num_blocks), heads, query length, key length]
      rel_weights = tf.concat(
          [rel_weights_unmasked, rel_weights_masked], axis=3)
      # Shared embeddings have no heads axis, so the relative axis is 0.
      if heads_share_relative_embedding:
        value_rel_embeddings_concat_axis = 0
      else:
        value_rel_embeddings_concat_axis = 1
      value_rel_embeddings = tf.concat(
          [value_rel_embeddings_unmasked, value_rel_embeddings_masked],
          axis=value_rel_embeddings_concat_axis)
      output_rel = matmul_with_relative_values(
          rel_weights, value_rel_embeddings, heads_share_relative_embedding)
      output += output_rel

    # bring to [batch, heads, num_blocks-1, block_length, depth]
    output = tf.reshape(output,
                        [batch, num_blocks-1, heads, block_length, depth_v])
    output = tf.transpose(output, [0, 2, 1, 3, 4])

    output = tf.reshape(
        output, [batch, heads, (num_blocks - 1) * block_length, depth_v])
    output = tf.concat([first_output, output], axis=2)
    # Remove the padding introduced above.
    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
    output = tf.reshape(output, [batch, heads, original_length, depth_v])
    return output
def matmul_with_relative_values(x, y, heads_share_relative_embedding):
  """Weights relative value embeddings by relative attention weights.

  Args:
    x: a rank-4 Tensor indexed (b, h, l, m); callers pass relative attention
      weights of shape [batch, heads, query_length, relative_positions].
    y: relative value embeddings, of shape [m, d] when shared across heads,
      otherwise [h, m, d].
    heads_share_relative_embedding: whether y has no heads axis.

  Returns:
    a Tensor of shape [b, h, l, d].
  """
  if heads_share_relative_embedding:
    equation = "bhlm,md->bhld"
  else:
    equation = "bhlm,hmd->bhld"
  return tf.einsum(equation, x, y)
def matmul_with_relative_keys(x, y, heads_share_relative_embedding):
  """Computes relative-position logits from queries and key embeddings.

  Args:
    x: a rank-4 Tensor of queries, indexed (b, h, l, d).
    y: relative key embeddings, of shape [m, d] when shared across heads,
      otherwise [h, m, d].
    heads_share_relative_embedding: whether y has no heads axis.

  Returns:
    a Tensor of shape [b, h, l, m] holding one logit per relative position.
  """
  equation = ("bhld,md->bhlm" if heads_share_relative_embedding
              else "bhld,hmd->bhlm")
  return tf.einsum(equation, x, y)
def local_attention_1d(q, k, v, block_length=128, filter_width=100, name=None):
  """Strided block local self-attention.

  The sequence is divided into blocks of length block_length. Attention for a
  given query position can see all memory positions in the corresponding block
  and filter_width many positions to the left and right of the block.
  Key positions outside the sequence are zero-padded and masked out via
  embedding_to_padding (NOTE(review): presumably this also masks any genuine
  all-zero key vector -- confirm).

  Args:
    q: a Tensor with shape [batch, heads, length, depth_k]
    k: a Tensor with shape [batch, heads, length, depth_k]
    v: a Tensor with shape [batch, heads, length, depth_v]
    block_length: an integer
    filter_width: an integer indicating how much to look left and right of the
      block.
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, length, depth_v]
  """
  with tf.variable_scope(
      name, default_name="local_self_attention_1d", values=[q, k, v]):
    # Check that q, k, v have the same shape except in their depth dimension.
    q.get_shape()[:-1].assert_is_compatible_with(k.get_shape()[:-1])
    q.get_shape()[:-1].assert_is_compatible_with(v.get_shape()[:-1])

    batch_size, num_heads, original_length, _ = common_layers.shape_list(q)

    # Pad query, key, value to ensure multiple of corresponding lengths.
    def pad_to_multiple(x, pad_length):
      x_length = common_layers.shape_list(x)[2]
      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])

    def pad_l_and_r(x, pad_length):
      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])

    # Set up query blocks.
    # [batch, heads, blocks_q, block_length, depth_k]
    q = pad_to_multiple(q, block_length)
    q = reshape_by_blocks(q, common_layers.shape_list(q), block_length)
    total_query_blocks = common_layers.shape_list(q)[2]

    # Set up key and value blocks.
    # [batch, heads, blocks_k, block_length, depth_k]
    blocks_per_filter_width = filter_width // block_length
    remaining_items = filter_width % block_length
    # Pad k/v by a whole number of blocks on each side, so that query block j
    # lines up with key block j + kv_padding_blocks. When filter_width is not
    # a multiple of block_length, one extra block per side is needed to hold
    # the partial blocks. Otherwise filter_width itself is a whole number of
    # blocks and is exactly enough; padding one more block there would shift
    # the attended window one block to the left.
    if remaining_items:
      kv_padding_blocks = blocks_per_filter_width + 1
    else:
      kv_padding_blocks = blocks_per_filter_width
    kv_padding = kv_padding_blocks * block_length
    k = pad_to_multiple(k, block_length)
    v = pad_to_multiple(v, block_length)
    k = pad_l_and_r(k, kv_padding)
    v = pad_l_and_r(v, kv_padding)
    k = reshape_by_blocks(k, common_layers.shape_list(k), block_length)
    v = reshape_by_blocks(v, common_layers.shape_list(v), block_length)

    total_kv_blocks = common_layers.shape_list(k)[2]

    slices = []
    # prepare the left-most and right-most partial blocks if needed
    if remaining_items:
      # Last remaining_items positions of key block j, i.e. the positions
      # exactly filter_width .. filter_width - remaining_items + 1 before the
      # start of query block j.
      first_partial_block_k = tf.slice(
          k, [0, 0, 0, block_length - remaining_items, 0],
          [-1, -1, total_query_blocks, -1, -1])
      first_partial_block_v = tf.slice(
          v, [0, 0, 0, block_length - remaining_items, 0],
          [-1, -1, total_query_blocks, -1, -1])
      # First remaining_items positions of the block just past the right-most
      # full block.
      last_partial_block_k = tf.slice(
          k, [0, 0, total_kv_blocks - total_query_blocks, 0, 0],
          [-1, -1, -1, remaining_items, -1])
      last_partial_block_v = tf.slice(
          v, [0, 0, total_kv_blocks - total_query_blocks, 0, 0],
          [-1, -1, -1, remaining_items, -1])
      slices.append((first_partial_block_k, first_partial_block_v))
      slices.append((last_partial_block_k, last_partial_block_v))

    # Prepare the rest of the blocks: full blocks at offsets
    # -blocks_per_filter_width .. +blocks_per_filter_width around the key
    # block aligned with each query block.
    first_block_index = kv_padding_blocks - blocks_per_filter_width
    attention_blocks = 2 * blocks_per_filter_width + 1
    for i in range(first_block_index, attention_blocks + first_block_index):
      block_k = tf.slice(k, [0, 0, i, 0, 0],
                         [-1, -1, total_query_blocks, -1, -1])
      block_v = tf.slice(v, [0, 0, i, 0, 0],
                         [-1, -1, total_query_blocks, -1, -1])
      slices.append((block_k, block_v))
    # [batch, heads, blocks_q, block_length + 2 * filter_width, depth_k]
    k = tf.concat([s[0] for s in slices], axis=3)
    v = tf.concat([s[1] for s in slices], axis=3)

    attention_bias = tf.expand_dims(embedding_to_padding(k) * -1e9, axis=-2)
    depth_v = common_layers.shape_list(v)[-1]

    output = dot_product_attention(
        q,
        k,
        v,
        attention_bias,
        dropout_rate=0.,
        name="local_1d",
        make_image_summary=False)
    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])

    # Remove the padding if introduced.
    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
    output.set_shape([None if isinstance(dim, tf.Tensor) else dim for dim in
                      (batch_size, num_heads, original_length, depth_v)])
    return output
def reshape_by_blocks(x, x_shape, memory_block_size):
  """Splits the length axis of x into consecutive blocks of memory_block_size.

  Args:
    x: a Tensor with shape [batch, heads, length, depth]
    x_shape: the shape of x as a 4-element sequence. Callers in this file pass
      common_layers.shape_list(...), so entries may be ints or scalar Tensors.
    memory_block_size: Integer which divides length.

  Returns:
    Tensor with shape
    [batch, heads, length // memory_block_size, memory_block_size, depth].
  """
  batch, heads, length, depth = (
      x_shape[0], x_shape[1], x_shape[2], x_shape[3])
  num_blocks = length // memory_block_size
  return tf.reshape(x, [batch, heads, num_blocks, memory_block_size, depth])
def dilated_self_attention_1d(q,
                              k,
                              v,
                              query_block_size=128,
                              memory_block_size=128,
                              gap_size=2,
                              num_memory_blocks=2,
                              name=None):
  """Dilated self-attention.

  The (padded) sequence is split into query blocks of query_block_size. Each
  query block attends to its own positions plus num_memory_blocks memory
  blocks of memory_block_size on each side, gathered by
  gather_dilated_memory_blocks. Key positions that are all-zero padding are
  masked out via embedding_to_padding.

  Args:
    q: a Tensor with shape [batch, heads, length, depth]
    k: a Tensor with shape [batch, heads, length, depth]
    v: a Tensor with shape [batch, heads, length, depth]
    query_block_size: an integer indicating size of query block
    memory_block_size: an integer indicating the size of a memory block.
    gap_size: an integer indicating the gap size
    num_memory_blocks: how many memory blocks to look at to the left and right.
      Each will be separated by gap_size. Presumably must be >= 1 (with 0,
      k_v_padding is 0 and k_t[:-0] below is empty) -- TODO confirm.
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, length, depth]
  """
  with tf.variable_scope(
      name, default_name="dilated_self_attention_1d", values=[q, k, v]):
    v_list_shape = v.get_shape().as_list()
    # NOTE(review): this compares the full static shapes of k and v, not only
    # their depths, and is skipped entirely under `python -O`.
    assert v_list_shape == k.shape.as_list(), "K and V depths must be equal"
    v_shape = common_layers.shape_list(v)
    depth_v = v_shape[3]
    batch_size = v_shape[0]
    num_heads = v_shape[1]
    original_length = common_layers.shape_list(q)[2]

    # Pad query, key, value to ensure multiple of corresponding lengths.
    def pad_to_multiple(x, pad_length):
      x_length = common_layers.shape_list(x)[2]
      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])

    def pad_l_and_r(x, pad_length):
      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])

    q = pad_to_multiple(q, query_block_size)
    v = pad_to_multiple(v, query_block_size)
    k = pad_to_multiple(k, query_block_size)

    # Set up query blocks.
    new_q_shape = common_layers.shape_list(q)
    q = reshape_by_blocks(q, new_q_shape, query_block_size)
    # k and v are reshaped with q's padded shape, which includes q's depth,
    # so this requires depth_k == depth_v == depth of q.
    self_k_part = reshape_by_blocks(k, new_q_shape, query_block_size)
    self_v_part = reshape_by_blocks(v, new_q_shape, query_block_size)

    # Set up key and value windows.
    # Enough room on each side for every (gap + memory block) of the edge
    # query blocks.
    k_v_padding = (gap_size + memory_block_size) * num_memory_blocks
    k = pad_l_and_r(k, k_v_padding)
    v = pad_l_and_r(v, k_v_padding)

    # Get gather indices.
    # Convolving the position sequence with an identity kernel
    # [memory_block_size, 1, memory_block_size] at stride query_block_size
    # yields, for each query block, memory_block_size consecutive positions
    # starting at that block's first position:
    # [num_query_blocks, memory_block_size].
    index_length = (new_q_shape[2] - query_block_size + memory_block_size)
    indices = tf.range(0, index_length, delta=1, name="index_range")
    indices = tf.reshape(indices, [1, -1, 1])  # [1, length, 1] for convs
    kernel = tf.expand_dims(tf.eye(memory_block_size), axis=1)
    gather_indices = tf.nn.conv1d(
        tf.cast(indices, tf.float32),
        kernel,
        query_block_size,
        padding="VALID",
        name="gather_conv")

    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)

    # Get left and right memory blocks for each query.
    # [length, batch, heads, dim]
    k_t = tf.transpose(k, [2, 0, 1, 3])
    v_t = tf.transpose(v, [2, 0, 1, 3])
    # Left memory is gathered from the sequence without its right padding...
    left_k = gather_dilated_memory_blocks(
        k_t[:-k_v_padding, :, :, :], num_memory_blocks, gap_size,
        query_block_size, memory_block_size, gather_indices)
    left_v = gather_dilated_memory_blocks(
        v_t[:-k_v_padding, :, :, :], num_memory_blocks, gap_size,
        query_block_size, memory_block_size, gather_indices)

    # ...and right memory from the sequence without its left padding.
    right_k = gather_dilated_memory_blocks(
        k_t[k_v_padding:, :, :, :],
        num_memory_blocks,
        gap_size,
        query_block_size,
        memory_block_size,
        gather_indices,
        direction="right")
    right_v = gather_dilated_memory_blocks(
        v_t[k_v_padding:, :, :, :],
        num_memory_blocks,
        gap_size,
        query_block_size,
        memory_block_size,
        gather_indices,
        direction="right")

    # [batch, heads, blocks, left + query_block_size + right, depth]
    k_windows = tf.concat([left_k, self_k_part, right_k], axis=3)
    v_windows = tf.concat([left_v, self_v_part, right_v], axis=3)
    attention_bias = tf.expand_dims(
        embedding_to_padding(k_windows) * -1e9, axis=-2)

    output = dot_product_attention(
        q,
        k_windows,
        v_windows,
        attention_bias,
        dropout_rate=0.,
        name="dilated_1d",
        make_image_summary=False)
    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])

    # Remove the padding if introduced.
    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
    output.set_shape(v_list_shape)
    return output
def gather_dilated_memory_blocks(x,
                                 num_memory_blocks,
                                 gap_size,
                                 query_block_size,
                                 memory_block_size,
                                 gather_indices,
                                 direction="left"):
  """Gathers blocks with gaps in between.

  For each block_id in range(num_memory_blocks), a shifted slice of x is
  gathered with gather_indices. block_id 0 is the block nearest the query
  block, and each further block moves another
  memory_block_size + gap_size positions away. Any direction other than
  "left" is treated as "right".

  Args:
    x: Tensor of shape [length, batch, heads, depth]
    num_memory_blocks: how many memory blocks to look in "direction". Each will
      be separated by gap_size.
    gap_size: an integer indicating the gap size
    query_block_size: an integer indicating size of query block
    memory_block_size: an integer indicating the size of a memory block.
    gather_indices: The indices to gather from; callers in this file pass
      shape [blocks, memory_block_size].
    direction: left or right

  Returns:
    Tensor of shape
    [batch, heads, blocks, num_memory_blocks * memory_block_size, depth]
    (per-block results concatenated along axis 3).
  """
  def _blocks_from(segment):
    # [blocks, block_length, batch, heads, depth] ->
    # [batch, heads, blocks, block_length, depth]
    return tf.transpose(tf.gather(segment, gather_indices), [2, 3, 0, 1, 4])

  stride = memory_block_size + gap_size
  pieces = []
  for block_id in range(num_memory_blocks):
    start = stride * (num_memory_blocks - (block_id + 1))
    stop = -(query_block_size + gap_size * (block_id + 1) +
             memory_block_size * block_id)
    if direction != "left":
      # Mirror the window: skip the leading positions, keep the trailing ones.
      start, stop = -stop, -start
    # A stop of 0 would give an empty slice, so it means "to the end".
    segment = x[start:] if stop == 0 else x[start:stop]
    pieces.append(_blocks_from(segment))
  return tf.concat(pieces, 3)
def masked_dilated_self_attention_1d(q,
                                     k,
                                     v,
                                     query_block_size=64,
                                     memory_block_size=64,
                                     gap_size=2,
                                     num_memory_blocks=2,
                                     name=None):
  """Masked (causal) dilated self-attention.

  TODO(avaswani): Try it and write a paper on it.

  Each query block attends to its own positions with a lower-triangular
  (causal) mask, plus num_memory_blocks memory blocks to the left only,
  gathered by gather_dilated_memory_blocks. All-zero (padding) memory
  positions are masked out via embedding_to_padding.

  Args:
    q: a Tensor with shape [batch, heads, length, depth]
    k: a Tensor with shape [batch, heads, length, depth]
    v: a Tensor with shape [batch, heads, length, depth]
    query_block_size: an integer
    memory_block_size: an integer giving the size of each memory block to
      the left.
    gap_size: an integer indicating the gap size
    num_memory_blocks: how many memory blocks to look at to the left. Each will
      be separated by gap_size.
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, length, depth]
  """
  with tf.variable_scope(
      name, default_name="masked_dilated_self_attention_1d", values=[q, k, v]):
    v_list_shape = v.get_shape().as_list()
    # NOTE(review): this compares the full static shapes of k and v, not only
    # their depths, and is skipped entirely under `python -O`.
    assert v_list_shape == k.shape.as_list(), "K and V depths must be equal"
    v_shape = common_layers.shape_list(v)
    depth_v = v_shape[3]
    batch_size = v_shape[0]
    num_heads = v_shape[1]
    original_length = common_layers.shape_list(q)[2]

    # Pad query, key, value to ensure multiple of corresponding lengths.
    def pad_to_multiple(x, pad_length):
      x_length = common_layers.shape_list(x)[2]
      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])

    def pad_l(x, left_pad_length):
      return tf.pad(x, [[0, 0], [0, 0], [left_pad_length, 0], [0, 0]])

    q = pad_to_multiple(q, query_block_size)
    v = pad_to_multiple(v, query_block_size)
    k = pad_to_multiple(k, query_block_size)

    # Set up query blocks.
    new_q_shape = common_layers.shape_list(q)
    q = reshape_by_blocks(q, new_q_shape, query_block_size)

    # Set up key and value windows.
    # k and v are reshaped with q's padded shape (including q's depth).
    self_k_part = reshape_by_blocks(k, new_q_shape, query_block_size)
    self_v_part = reshape_by_blocks(v, new_q_shape, query_block_size)
    # Only the past is visible, so k/v are padded on the left only.
    k_v_padding = (gap_size + memory_block_size) * num_memory_blocks
    k = pad_l(k, k_v_padding)
    v = pad_l(v, k_v_padding)

    # Get gather indices.
    # Identity-kernel convolution at stride query_block_size: for each query
    # block, memory_block_size consecutive positions starting at the block's
    # first position -> [num_query_blocks, memory_block_size].
    index_length = (new_q_shape[2] - query_block_size + memory_block_size)
    indices = tf.range(0, index_length, delta=1, name="index_range")
    indices = tf.reshape(indices, [1, -1, 1])  # [1, length, 1] for convs
    kernel = tf.expand_dims(tf.eye(memory_block_size), axis=1)
    gather_indices = tf.nn.conv1d(
        tf.cast(indices, tf.float32),
        kernel,
        query_block_size,
        padding="VALID",
        name="gather_conv")
    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)

    # Get the left memory blocks for each query (there is no right memory in
    # the masked variant).
    # [length, batch, heads, dim]
    k_t = tf.transpose(k, [2, 0, 1, 3])
    v_t = tf.transpose(v, [2, 0, 1, 3])

    k_unmasked_windows = gather_dilated_memory_blocks(
        k_t, num_memory_blocks, gap_size, query_block_size, memory_block_size,
        gather_indices)
    v_unmasked_windows = gather_dilated_memory_blocks(
        v_t, num_memory_blocks, gap_size, query_block_size, memory_block_size,
        gather_indices)

    # Combine memory windows.
    # Bias for the self part: causal mask tiled over [batch, heads, blocks]
    # (assumes attention_bias_lower_triangle returns a rank-4 bias -- confirm).
    block_q_shape = common_layers.shape_list(q)
    masked_attention_bias = tf.tile(
        tf.expand_dims(attention_bias_lower_triangle(query_block_size), axis=0),
        [block_q_shape[0], block_q_shape[1], block_q_shape[2], 1, 1])
    # Bias for the memory part: only padding is masked, repeated per query.
    padding_attention_bias = tf.expand_dims(
        embedding_to_padding(k_unmasked_windows) * -1e9, axis=-2)
    padding_attention_bias = tf.tile(padding_attention_bias,
                                     [1, 1, 1, query_block_size, 1])
    # Same order (self part first, then memory) as k_windows below.
    attention_bias = tf.concat(
        [masked_attention_bias, padding_attention_bias], axis=-1)
    # combine memory windows
    k_windows = tf.concat([self_k_part, k_unmasked_windows], 3)
    v_windows = tf.concat([self_v_part, v_unmasked_windows], 3)
    output = dot_product_attention(
        q,
        k_windows,
        v_windows,
        attention_bias,
        dropout_rate=0.,
        name="dilated_1d",
        make_image_summary=False)
    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])

    # Remove the padding if introduced.
    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
    output.set_shape(v_list_shape)
    return output
def local_attention_2d(q,
                       k,
                       v,
                       query_shape=(8, 16),
                       memory_flange=(8, 16),
                       name=None):
  """Strided block local self-attention.

  The 2-D sequence is divided into 2-D blocks of shape query_shape. Each query
  position attends to every position of its own block, extended by
  memory_flange[0] rows above and below and memory_flange[1] columns to the
  left and right. No causal masking is applied (masked_local_attention_2d is
  the masked variant). Memory positions outside the input are zero-padded and
  masked out as padding via embedding_to_padding.

  Args:
    q: a Tensor with shape [batch, heads, h, w, depth_k]
    k: a Tensor with shape [batch, heads, h, w, depth_k]
    v: a Tensor with shape [batch, heads, h, w, depth_v]. In the current
      implementation, depth_v must be equal to depth_k.
    query_shape: a tuple indicating the height and width of each query block.
    memory_flange: a tuple (flange_height, flange_width) indicating how far
      the memory extends beyond each query block vertically and horizontally.
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, h, w, depth_v]
  """
  with tf.variable_scope(
      name, default_name="local_self_attention_2d", values=[q, k, v]):
    v_shape = common_layers.shape_list(v)

    # Pad query, key, value to ensure multiple of corresponding lengths.
    q = pad_to_multiple_2d(q, query_shape)
    k = pad_to_multiple_2d(k, query_shape)
    v = pad_to_multiple_2d(v, query_shape)
    # Pad the height by memory_flange[0] and the width by memory_flange[1] on
    # both sides. Memory windows of memory_shape (below) placed at stride
    # query_shape then give exactly one window per query block, centered on
    # it. Mixing the two flanges within one axis misaligns the windows and,
    # whenever memory_flange[0] != memory_flange[1], yields a different number
    # of memory blocks than query blocks.
    paddings = [[0, 0], [0, 0], [memory_flange[0], memory_flange[0]],
                [memory_flange[1], memory_flange[1]], [0, 0]]
    k = tf.pad(k, paddings)
    v = tf.pad(v, paddings)

    # Set up query blocks.
    q_indices = gather_indices_2d(q, query_shape, query_shape)
    q_new = gather_blocks_2d(q, q_indices)

    # Set up key and value blocks.
    memory_shape = (query_shape[0] + 2 * memory_flange[0],
                    query_shape[1] + 2 * memory_flange[1])
    k_and_v_indices = gather_indices_2d(k, memory_shape, query_shape)
    k_new = gather_blocks_2d(k, k_and_v_indices)
    v_new = gather_blocks_2d(v, k_and_v_indices)

    attention_bias = tf.expand_dims(
        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
    output = dot_product_attention(
        q_new,
        k_new,
        v_new,
        attention_bias,
        dropout_rate=0.,
        name="local_2d",
        make_image_summary=False)
    # Put representations back into original shapes.
    padded_q_shape = common_layers.shape_list(q)
    output = scatter_blocks_2d(output, q_indices, padded_q_shape)

    # Remove the padding if introduced.
    output = tf.slice(output, [0, 0, 0, 0, 0],
                      [-1, -1, v_shape[2], v_shape[3], -1])
    return output
def pad_to_multiple_2d(x, block_shape):
  """Pads the height and width of x up to multiples of block_shape.

  Zeros are added only at the bottom of the height axis and the right of the
  width axis.

  Args:
    x: a [batch, heads, h, w, depth] or [batch, h, w, depth] tensor
    block_shape: a 2-d list of integer shapes (block_height, block_width)

  Returns:
    padded_x: a [batch, heads, h', w', depth] or [batch, h', w', depth] tensor,
      where h' and w' are the smallest multiples of block_shape[0] and
      block_shape[1] that are >= h and w.

  Raises:
    ValueError: if the static rank of x is unknown or is not 4 or 5.
  """
  old_shape = x.get_shape().dims
  rank = None if old_shape is None else len(old_shape)
  if rank not in (4, 5):
    # Previously this fell through to an unbound `paddings` variable.
    raise ValueError(
        "pad_to_multiple_2d expects a tensor of rank 4 or 5, got rank %s."
        % rank)
  last = old_shape[-1]
  x_shape = common_layers.shape_list(x)
  if rank == 4:
    height_padding = -x_shape[1] % block_shape[0]
    width_padding = -x_shape[2] % block_shape[1]
    paddings = [[0, 0], [0, height_padding], [0, width_padding], [0, 0]]
  else:
    height_padding = -x_shape[2] % block_shape[0]
    width_padding = -x_shape[3] % block_shape[1]
    paddings = [[0, 0], [0, 0], [0, height_padding], [0, width_padding],
                [0, 0]]
  padded_x = tf.pad(x, paddings)
  # Re-assert the static depth. Presumably tf.pad can drop static dimensions
  # when the paddings are dynamic (unknown h or w) -- kept as before.
  padded_shape = padded_x.get_shape().as_list()
  padded_shape = padded_shape[:-1] + [last]
  padded_x.set_shape(padded_shape)
  return padded_x
def reshape_range(tensor, i, j, shape):
  """Reshapes the dimensions in the half-open range [i, j) of tensor.

  Args:
    tensor: a Tensor.
    i: index of the first dimension to replace.
    j: one past the index of the last dimension to replace.
    shape: a list of sizes that replaces dimensions i .. j-1.

  Returns:
    tensor reshaped to dims[:i] + shape + dims[j:], where dims is
    common_layers.shape_list(tensor).
  """
  dims = common_layers.shape_list(tensor)
  leading, trailing = dims[:i], dims[j:]
  return tf.reshape(tensor, leading + shape + trailing)
def gather_blocks_2d(x, indices):
  """Gathers flattened 2-D blocks from x.

  Args:
    x: a Tensor of shape [batch, heads, height, width, depth].
    indices: an int Tensor of shape [num_blocks, block_size] holding
      row-major flat positions (row * width + col), e.g. from
      gather_indices_2d.

  Returns:
    a Tensor of shape [batch, heads, num_blocks, block_size, depth], where
    block_size is the number of positions per block
    (block_height * block_width).
  """
  dims = common_layers.shape_list(x)
  flat_length = tf.reduce_prod(dims[2:4])
  # Merge height and width: [batch, heads, height * width, depth].
  flat = reshape_range(x, 2, 4, [flat_length])
  # Put positions first so tf.gather indexes them: [positions, batch, heads,
  # depth].
  by_position = tf.transpose(flat, [2, 0, 1, 3])
  blocks = tf.gather(by_position, indices)
  return tf.transpose(blocks, [2, 3, 0, 1, 4])
def scatter_blocks_2d(x, indices, shape):
  """Scatters flattened blocks from x back into a tensor of the given shape.

  Callers in this file use it to undo gather_blocks_2d with the same
  non-overlapping query indices.

  Args:
    x: a Tensor whose first two axes are [batch, heads] and last is depth,
      typically [batch, heads, num_blocks, block_size, depth].
    indices: an int Tensor of flat positions, one per block element.
    shape: the output shape, e.g. [batch, heads, height, width, depth].

  Returns:
    a Tensor reshaped to shape.
  """
  dims = common_layers.shape_list(x)
  batch, heads, depth = dims[0], dims[1], dims[-1]
  # [batch, heads, positions, depth] -> [positions, batch, heads, depth]
  flat = tf.reshape(x, [batch, heads, -1, depth])
  by_position = tf.transpose(flat, [2, 0, 1, 3])
  # NOTE(review): the scatter target has exactly as many positions as x has
  # block elements, so this assumes the blocks tile the output exactly once.
  target_shape = common_layers.shape_list(by_position)
  flat_indices = tf.reshape(indices, [-1, 1])
  scattered = tf.scatter_nd(flat_indices, by_position, target_shape)
  return tf.reshape(tf.transpose(scattered, [1, 2, 0, 3]), shape)
def gather_indices_2d(x, block_shape, block_stride):
  """Computes flat position indices of 2-D blocks for gather_blocks_2d.

  Blocks of shape block_shape are placed over the height/width axes of x with
  step block_stride, using VALID placement (no block extends past the edge).

  Args:
    x: a Tensor of shape [batch, heads, height, width, depth]; only its shape
      is used.
    block_shape: (block_height, block_width).
    block_stride: (stride_height, stride_width).

  Returns:
    an int32 Tensor of shape [num_blocks, block_height * block_width] with the
    row-major flat positions (row * width + col) of each block. Blocks are
    ordered row-major.
  """
  block_h, block_w = block_shape[0], block_shape[1]
  # An identity matrix reshaped to [block_h, block_w, 1, block_h * block_w]
  # copies each position of a window into its own output channel.
  identity = tf.eye(block_h * block_w)
  kernel = reshape_range(identity, 0, 1, [block_h, block_w, 1])

  x_shape = common_layers.shape_list(x)
  height, width = x_shape[2], x_shape[3]
  # Flat position of every cell, laid out as a [1, height, width, 1] image.
  positions = tf.reshape(tf.range(height * width), [1, height, width, 1])
  # NOTE(review): positions pass through float32 here, which represents
  # integers exactly only up to 2**24 (height * width).
  windows = tf.nn.conv2d(
      tf.cast(positions, tf.float32),
      kernel,
      strides=[1, block_stride[0], block_stride[1], 1],
      padding="VALID")

  # [1, blocks_h, blocks_w, block_size] -> [num_blocks, block_size]
  grid_dims = common_layers.shape_list(windows)[:3]
  if all(isinstance(d, int) for d in grid_dims):
    # Fully static grid: compute the block count in Python.
    num_blocks = functools.reduce(operator.mul, grid_dims, 1)
  else:
    num_blocks = tf.reduce_prod(grid_dims)
  return tf.cast(tf.reshape(windows, [num_blocks, -1]), tf.int32)
def make_2d_block_raster_mask(query_shape, memory_flange):
  """Creates a mask for 2d block raster scan.

  The query mask can look to the left, top left, top, and top right, but
  not to the right. Inside the query, we have the standard raster scan
  masking.

  Memory columns are laid out as the top band (memory_flange[0] rows of width
  query_shape[1] + 2 * memory_flange[1]), followed, for each query row, by the
  left flange, that query row, and the right flange.

  Args:
    query_shape: A tuple of ints (query_height, query_width)
    memory_flange: A tuple of ints
      (memory_flange_height, memory_flange_width)

  Returns:
    A float tensor of shape [query_size, memory_size], where
    query_size = query_height * query_width and
    memory_size = (query_width + 2 * memory_flange_width) *
    (memory_flange_height + query_height).
    0.0 marks a visible location and 1.0 a masked one.
  """
  query_size = np.prod(query_shape)
  flange_height, flange_width = memory_flange[0], memory_flange[1]

  # Lower-triangular (raster-scan) visibility inside the query block, split
  # column-wise into one [query_size, query_width] slab per query row.
  causal = common_layers.ones_matrix_band_part(query_size, query_size, -1, 0)
  row_slabs = tf.split(causal, query_shape[0], axis=1)

  def _row_piece(row):
    # Left flange fully visible, causal slab, right flange fully hidden.
    return tf.concat(
        [tf.ones([query_size, flange_width]),
         row_slabs[row],
         tf.zeros([query_size, flange_width])],
        axis=1)

  row_pieces = [_row_piece(row) for row in range(query_shape[0])]

  # The whole top band (top-left, top and top-right) is visible.
  top_band = tf.ones(
      [query_size, (query_shape[1] + 2 * flange_width) * flange_height])
  visible = tf.concat([top_band, tf.concat(row_pieces, axis=1)], axis=1)
  return 1. - visible
def get_memory_region(x, query_block_shape, memory_flange, q_indices):
  """Get the memory regions that surround a 2d query.

  The memory regions are the left flange (memory_flange[1] columns in the same
  rows as the query block) and the top band (memory_flange[0] rows covering
  the top left, top and top right of the query block).

  Args:
    x: A tensor with shape [batch, heads, height, width, depth]
    query_block_shape: a 2-d tuple of integers
    memory_flange: a 2-d tuple of integers
    q_indices: a tensor of indices for each of the center blocks.
      [num_blocks, block_length]

  Returns:
    A tuple (x_flange, x_center):
      x_flange: A tensor of shape [batch, heads, #blocks, flange_length, depth]
        holding the top region followed by the left region along axis 3. Only
        one of them is present if the other flange is 0, and it is None if
        both flanges are 0.
      x_center: the blocks of x (padded to a multiple of query_block_shape)
        gathered with q_indices, [batch, heads, #blocks, block_length, depth].
  """
  # Padding x to be multiple of query_shape and then
  # extracting the memory blocks from the same regions as the query blocks
  x_query_padded = pad_to_multiple_2d(x, query_block_shape)
  x_center = gather_blocks_2d(x_query_padded, q_indices)
  # Then padding the flange region
  # (above the input and on both the left and right; nothing below, since no
  # memory region lies below the query block).
  paddings = [[0, 0], [0, 0], [memory_flange[0], 0],
              [memory_flange[1], memory_flange[1]], [0, 0]]
  x_memory_padded = tf.pad(x_query_padded, paddings)
  left_x = None
  top_x = None
  # Extracting the memory regions around the query block. left_x_region extends
  # to the left and the top_x_region is the combination of top left, top, and
  # top right of the query block
  # The left region is skipped when memory_flange[1] is 0.
  if memory_flange[1] > 0:
    # Drop the top padding rows and the right-most
    # (query_block_shape[1] + memory_flange[1]) columns, so that a window of
    # width memory_flange[1] placed at stride query_block_shape ends just
    # left of each query block.
    left_x_region = x_memory_padded[:, :, memory_flange[
        0]:, :-(query_block_shape[1] + memory_flange[1]), :]
    left_memory_shape = (query_block_shape[0], memory_flange[1])
    left_indices = gather_indices_2d(left_x_region, left_memory_shape,
                                     query_block_shape)
    left_x = gather_blocks_2d(left_x_region, left_indices)
  # The top region is skipped when memory_flange[0] is 0.
  if memory_flange[0] > 0:
    # Drop the last query_block_shape[0] rows, so that a window of height
    # memory_flange[0] placed at stride query_block_shape ends just above each
    # query block.
    top_x_region = x_memory_padded[:, :, :-query_block_shape[0], :, :]

    top_memory_shape = (memory_flange[0],
                        query_block_shape[1] + 2 * memory_flange[1])

    top_indices = gather_indices_2d(top_x_region, top_memory_shape,
                                    query_block_shape)

    top_x = gather_blocks_2d(top_x_region, top_indices)
  x_flange = None
  if top_x is not None and left_x is not None:
    x_flange = tf.concat([top_x, left_x], axis=3)
  else:
    x_flange = top_x if top_x is not None else left_x
  return x_flange, x_center
def get_shifted_center_blocks(x, indices):
  """Get right shifted blocks for masked local attention 2d.

  Args:
    x: A tensor with shape [batch, heads, height, width, depth]
    indices: The indices to gather blocks

  Returns:
    x_shifted: a tensor of extracted blocks [batch, heads, num_blocks,
      block_size, depth], each block independently right shifted by one along
      its length axis, with a zero in its first position.
  """
  blocks = gather_blocks_2d(x, indices)
  # Prepend a zero along the within-block length axis (axis 3), then drop the
  # final position so the block length is unchanged.
  zero_led = tf.pad(blocks, [[0, 0], [0, 0], [0, 0], [1, 0], [0, 0]])
  return zero_led[:, :, :, :-1, :]
def right_shift_blockwise(x, query_shape, name=None):
  """Right shifts once in every block.

  Within each query_shape block (read in raster order), every position takes
  the value of the position before it, and the first position of each block
  becomes zero.

  Args:
    x: a tensor of shape [batch, height, width, depth]
    query_shape: A 2d tuple of ints
    name: a string

  Returns:
    output: a tensor of the same shape as x
  """
  with tf.variable_scope(
      name, default_name="right_shift_blockwise", values=[x]):
    static_shape = x.get_shape().as_list()
    dynamic_shape = common_layers.shape_list(x)
    # The 2-D block helpers expect [batch, heads, h, w, depth]; add one dummy
    # head and pad h and w to multiples of query_shape.
    with_head = pad_to_multiple_2d(tf.expand_dims(x, axis=1), query_shape)
    padded_shape = common_layers.shape_list(with_head)
    block_indices = gather_indices_2d(with_head, query_shape, query_shape)
    shifted_blocks = get_shifted_center_blocks(with_head, block_indices)
    # Undo the blocking, then drop the dummy head axis and the padding.
    restored = tf.squeeze(
        scatter_blocks_2d(shifted_blocks, block_indices, padded_shape), axis=1)
    result = tf.slice(restored, [0, 0, 0, 0],
                      [-1, dynamic_shape[1], dynamic_shape[2], -1])
    result.set_shape(static_shape)
    return result
def right_shift_blockwise_nd(x, block_shape):
  """Right shifts `x` by one position in blocked raster-scan order.

  All blocks are flattened into a single sequence before shifting. As a
  result, the first position of each block receives the last position of the
  preceding block. Only the very first position overall becomes zero.

  Args:
    x: a [batch, d1, d2, ..., dn, depth] tensor
    block_shape: a tuple (q1, q2, ..., qn) representing the block shape

  Returns:
    a [batch, d1, d2, ..., dn, depth] tensor, right shifted.
  """
  blocks = break_into_blocks_nd(x, block_shape)
  blocks_shape = common_layers.shape_list(blocks)
  # Merge every block axis and the within-block axis into one sequence axis.
  flat = tf.reshape(blocks, [blocks_shape[0], -1, blocks_shape[-1]])
  total_positions = np.prod(blocks_shape[1:-1], dtype=np.int32)
  # Prepend a zero and keep the original number of positions.
  shifted = tf.slice(tf.pad(flat, [[0, 0], [1, 0], [0, 0]]), [0, 0, 0],
                     [-1, total_positions, -1])
  return put_back_blocks_nd(tf.reshape(shifted, blocks_shape), block_shape)
def masked_local_attention_2d(q,
                              k,
                              v,
                              query_shape=(8, 16),
                              memory_flange=(8, 16),
                              name=None):
  """Strided block local self-attention.

  Each position in a query block can attend to the positions of its own query
  block that come no later in raster-scan order (via a lower-triangular bias
  from `attention_bias_lower_triangle`). It can also attend to memory
  positions to the left and top of the block. The shapes are specified by
  query shape and memory flange.

  NOTE(review): an earlier version of this docstring said right shifting
  happens inside this function; no shift is visible in this body. Confirm
  whether callers must shift their inputs (e.g. with right_shift_blockwise).

  Args:
    q: a Tensor with shape [batch, heads, h, w, depth_k]
    k: a Tensor with shape [batch, heads, h, w, depth_k]
    v: a Tensor with shape [batch, heads, h, w, depth_v]. In the current
      implementation, depth_v must be equal to depth_k.
    query_shape: a tuple (query_h, query_w) giving the height and width of each
      query block (query_shape = block_shape).
    memory_flange: a tuple (flange_h, flange_w) giving how far beyond each
      query block to look in height and width.
      memory shape = query_shape + (memory_flange[0], 2 * memory_flange[1])
    name: an optional string

  Returns:
    a Tensor of shape [batch, heads, h, w, depth_v]
  """
  with tf.variable_scope(
      name, default_name="local_masked_self_attention_2d", values=[q, k, v]):
    v_shape = common_layers.shape_list(v)
    # Pad query to ensure multiple of corresponding lengths.
    q = pad_to_multiple_2d(q, query_shape)
    # Set up query blocks.
    q_indices = gather_indices_2d(q, query_shape, query_shape)
    q_new = gather_blocks_2d(q, q_indices)
    # Set up key and value blocks: a flange region (may be None) plus the
    # center region aligned with each query block.
    k_flange, k_center = get_memory_region(k, query_shape, memory_flange,
                                           q_indices)
    v_flange, v_center = get_memory_region(v, query_shape, memory_flange,
                                           q_indices)
    query_elements = np.prod(query_shape)
    padding_mask = None
    if k_flange is not None:
      k_new = tf.concat([k_flange, k_center], axis=3)
      v_new = tf.concat([v_flange, v_center], axis=3)
      # The flange is fully visible except where it is padding, which
      # embedding_to_padding marks and which gets a large negative bias.
      padding_mask = tf.expand_dims(
          embedding_to_padding(k_flange) * -1e9, axis=-2)
      padding_mask = tf.tile(padding_mask, [1, 1, 1, query_elements, 1])
    else:
      k_new = k_center
      v_new = v_center
    # Causal mask inside the center block. query_elements is already the
    # number of positions per block, so it is passed directly.
    center_attention_bias = attention_bias_lower_triangle(query_elements)
    center_attention_bias = tf.reshape(
        center_attention_bias, [1, 1, 1, query_elements, query_elements])
    v_center_shape = common_layers.shape_list(v_center)
    center_attention_bias = tf.tile(
        center_attention_bias,
        [v_center_shape[0], v_center_shape[1], v_center_shape[2], 1, 1])
    if padding_mask is not None:
      # Combine the mask for padding and visible region.
      attention_bias = tf.concat([padding_mask, center_attention_bias], axis=4)
    else:
      attention_bias = center_attention_bias
    output = dot_product_attention(
        q_new,
        k_new,
        v_new,
        attention_bias,
        dropout_rate=0.,
        name="masked_local_2d",
        make_image_summary=False)
    # Put representations back into original shapes.
    padded_q_shape = common_layers.shape_list(q)
    output = scatter_blocks_2d(output, q_indices, padded_q_shape)
    # Remove the padding if introduced.
    output = tf.slice(output, [0, 0, 0, 0, 0],
                      [-1, -1, v_shape[2], v_shape[3], -1])
    return output
def masked_local_attention_nd(q,
                              k,
                              v,
                              query_shape,
                              memory_flange,
                              decode_step=None,
                              name=None):
  """Masked local attention nd.

  Each position in q can attend to positions in memory that are positioned less
  than or equal to query position according to raster scan ordering and are in
  the same memory block. A memory block is n-dimensional and each dimension 'i'
  is of size q[i] + 2 * m[i] except for the first dimension which is of size
  q[0] + m[0]. NOTE: This computation assumes memory_flange is divisible by
  query_shape in every dimension.

  Args:
    q: a [batch, heads, d1, d2, ..., dn, depth_k] tensor or a [batch, heads, 1,
      1, ..., 1, depth_k] tensor in decoding mode.
    k: a [batch, heads, d1, d2, ..., dn, depth_k] tensor
    v: a [batch, heads, d1, d2, ..., dn, depth_v] tensor
    query_shape: a tuple (q1, q2, ..., qn) indicating the shape of query blocks.
    memory_flange: a tuple (m1, m2, ..., mn) indicating the number of extra
      positions in the attention memory. memory_shape=[q1 + m1, q2 + 2 * m2,
      ..., qn + 2 * mn]
    decode_step: an integer in fast decoding mode.
    name: an optional string

  Returns:
    a [batch, head, d1, d2, ..., dn, depth_v] tensor or
    [batch, head, 1, 1, ..., 1, depth_v] if decode_step is not None.
  """
  assert all([m % b == 0 for m, b in zip(memory_flange, query_shape)])
  with tf.variable_scope(
      name, default_name="masked_local_attention_nd", values=[q, k, v]):
    # This computation only applies to self attention, so assert q, k and v have
    # the same dimensions.
    if decode_step is None:
      q.get_shape().assert_is_compatible_with(k.get_shape())
      q.get_shape()[:-1].assert_is_compatible_with(v.get_shape()[:-1])
    else:
      k.get_shape().assert_is_compatible_with(v.get_shape())

    # move heads to batch dimension. This is needed to reduce number of
    # dimensions as much as possible, since most ops support only up to 7
    # dimensions.
    q_shape = common_layers.shape_list(q)
    k_shape = common_layers.shape_list(k)
    v_shape = common_layers.shape_list(v)
    q = tf.reshape(q, [-1] + q_shape[2:])
    k = tf.reshape(k, [-1] + k_shape[2:])
    v = tf.reshape(v, [-1] + v_shape[2:])

    # Pad query, key, value to ensure multiple of corresponding lengths.
    if decode_step is None:
      # don't pad query in fast decoding mode. We only need to calculate self
      # attention for one position.
      q = pad_to_multiple_nd(q, query_shape)
    # k and v are padded in both modes: break_into_memory_blocks_nd below
    # requires every spatial dimension to be divisible by query_shape.
    k = pad_to_multiple_nd(k, query_shape)
    v = pad_to_multiple_nd(v, query_shape)

    # extract query and memory blocks
    if decode_step is None:
      q = break_into_blocks_nd(q, query_shape)
    else:
      # in fast decoding, q has 1 block with 1 item in it
      # q shape will be [batch] + [1] * n + [1, depth] which is equivalent of
      # [batch, b1, b2, ..., bn, items_in_block, depth] where there is 1 block
      # and 1 item in that block
      # NOTE(review): len(q_shape) - 3 equals n, so the reshape below actually
      # yields [batch] + [1] * n + [depth] (one axis fewer than stated above).
      # This still works because flatten/unflatten_blocks_nd only collapse
      # and restore singleton axes; confirm before relying on the exact rank.
      q = tf.reshape(q, [-1] + [1] * (len(q_shape) - 3) + [q_shape[-1]])
    k = break_into_memory_blocks_nd(k, query_shape, memory_flange, masked=True)
    v = break_into_memory_blocks_nd(v, query_shape, memory_flange, masked=True)

    # extract just one block of k and v in fast decoding mode.
    if decode_step is not None:
      k = select_block_for_decode_step(k, decode_step, query_shape)
      v = select_block_for_decode_step(v, decode_step, query_shape)

    # flatten q, k and v to [batch, num_blocks, items_in_block, depth]
    q, blocks_per_dim = flatten_blocks_nd(q)
    k, _ = flatten_blocks_nd(k)
    v, _ = flatten_blocks_nd(v)

    # make attention bias for causal attention.
    causal_attn_bias = causal_attention_bias_nd(
        query_shape, memory_flange, decode_step=decode_step)
    # NOTE(review): the padding bias is derived from v[:1], i.e. only the first
    # element of the merged batch*heads axis, and has leading dimension 1. That
    # matches padding added by pad_to_multiple_nd (identical for every
    # element), but would not capture per-example all-zero positions in other
    # batch elements — confirm this is intended.
    padding_attn_bias = tf.expand_dims(
        embedding_to_padding(v[:1, :, :, :]) * -1e9, axis=-2)

    if decode_step is None:
      num_blocks = common_layers.shape_list(v)[1]
      causal_attn_bias = tf.tile(causal_attn_bias, [1, num_blocks, 1, 1])
      padding_attn_bias = tf.tile(
          padding_attn_bias,
          [1, 1, np.prod(query_shape, dtype=np.int32), 1])
    # Both biases are 0 (visible) or -1e9 (masked); the element-wise minimum
    # masks a position if either bias masks it.
    attn_bias = tf.minimum(causal_attn_bias, padding_attn_bias)

    # Calculate dot product attention
    output = dot_product_attention(
        q,
        k,
        v,
        attn_bias,
        dropout_rate=0.,
        name=name or "masked_local_nd",
        make_image_summary=False)

    # restructure the output from blocks ordering to the original ordering
    output = unflatten_blocks_nd(output, blocks_per_dim)
    if decode_step is None:
      # In fast decoding, output only contains one element, this is not needed.
      output = put_back_blocks_nd(output, query_shape)

    # bring back the heads dimension
    output_shape = common_layers.shape_list(output)
    output = tf.reshape(output, q_shape[:2] + output_shape[1:])
    if decode_step is None:
      # No padding is introduced in fast decoding, no need to do this.
      output_shape = common_layers.shape_list(output)
      output = tf.slice(output, [0] * len(output_shape),
                        [-1, -1] + q_shape[2:-1] + [-1])
    return output
def select_block_for_decode_step(blocked_x, decode_step, query_shape):
  """Returns the single block of `blocked_x` that holds position `decode_step`.

  NOTE: This method only works for blocked inputs. `decode_step` is
  interpreted in blocked raster-scan order, and the block that contains it is
  sliced out.

  Args:
    blocked_x: a [batch, blocks_per_d1, ..., blocks_per_dn, b1 * ...* bn, depth]
      tensor
    decode_step: an integer
    query_shape: a tuple (q1, q2, ..., qn) representing query shape

  Returns:
    a [batch, [1] * n, b1 * ... * bn, depth] tensor
  """
  shape = common_layers.shape_list(blocked_x)
  block_counts = shape[1:-2]
  # The unblocked extent per dimension is (number of blocks) * (block size).
  unblocked_shape = [
      count * size for count, size in zip(block_counts, query_shape)
  ]
  position = decode_step_to_index(decode_step, query_shape, unblocked_shape)
  block_position = [p // size for p, size in zip(position, query_shape)]
  begin = [0] + block_position + [0, 0]
  # Explicit non-negative sizes (no -1): TPU requires this when `begin` is not
  # a compile-time constant.
  size = [shape[0]] + [1] * len(block_position) + shape[-2:]
  return tf.slice(blocked_x, begin, size)
def flatten_blocks_nd(x):
  """Merges the per-dimension block axes of `x` into a single axis.

  Args:
    x: a [batch, b1, ..., bn, items_in_block, depth] tensor

  Returns:
    A tuple of:
      a flattened tensor of shape [batch, b1 * ... * bn, items_in_block, depth]
      a list of [b1, ..., bn] which is used for unflattening.
  """
  shape = common_layers.shape_list(x)
  blocks_per_dim = shape[1:-2]
  flat_shape = [-1, np.prod(blocks_per_dim, dtype=np.int32)] + shape[-2:]
  return tf.reshape(x, flat_shape), blocks_per_dim
def unflatten_blocks_nd(x, blocks_per_dimension):
  """Splits the single block axis of `x` back into one axis per dimension.

  Inverse of `flatten_blocks_nd`.

  Args:
    x: a [batch, d1 * ... dn, items_in_block, depth] tensor
    blocks_per_dimension: a n-d list of integers for number of blocks in each
      dimension.

  Returns:
    a [batch, d1, d2, ..., dn, items_in_block, depth] tensor
  """
  shape = common_layers.shape_list(x)
  expected_blocks = np.prod(blocks_per_dimension, dtype=np.int32)
  assert shape[1] == expected_blocks
  target_shape = [-1] + list(blocks_per_dimension) + shape[-2:]
  return tf.reshape(x, target_shape)
def break_into_memory_blocks_nd(x, query_shape, memory_flange, masked=False):
  """Break a tensor into memory blocks around query blocks.

  This requires memory_flange to be divisible by query_shape in every dimension.
  Memory blocks are assembled from query-sized blocks, and the concatenation
  order follows `itertools.product` over the block offsets (first dimension
  slowest).

  Args:
    x: a [batch, d1, d2, ..., dn, depth] tensor
    query_shape: a n-d list of integers representing query shape
    memory_flange: an n-d list of integers representing memory flange.
    masked: a boolean for masked vs unmasked attention.

  Returns:
    a [batch, blocks_per_d1, ..., blocks_per_dn, b1 * ...* bn, depth] where bi
    is the memory block size in dimension i which is equal to q[i] + 2m[i], or
    q[i] + m[i] for the first dimension in masked attention.
  """
  assert all([m % b == 0 for b, m in zip(query_shape, memory_flange)])
  input_shape = common_layers.shape_list(x)
  flange_blocks = [m // b for b, m in zip(query_shape, memory_flange)]
  # Number of query blocks along each spatial dimension of the unpadded x.
  query_block_counts = [
      length // size for length, size in zip(input_shape[1:-1], query_shape)
  ]

  # Pad so every query block has a full flange around it. In masked mode the
  # first dimension only looks backwards, so it is padded at the start only.
  if masked:
    spatial_padding = ([[memory_flange[0], 0]] +
                       [[p, p] for p in memory_flange[1:]])
  else:
    spatial_padding = [[p, p] for p in memory_flange]
  x = tf.pad(x, [[0, 0]] + spatial_padding + [[0, 0]])
  query_blocks = break_into_blocks_nd(x, query_shape)

  # Block offsets per dimension: [0, n] for the masked first dimension,
  # otherwise [0, 2n].
  offsets_per_dim = [
      range(n + 1 if masked and dim == 0 else 2 * n + 1)
      for dim, n in enumerate(flange_blocks)
  ]
  slice_size = [-1] + query_block_counts + [-1, -1]
  pieces = [
      tf.slice(query_blocks, [0] + list(offsets) + [0, 0], slice_size)
      for offsets in itertools.product(*offsets_per_dim)
  ]
  # Stack the pieces along the items-in-block axis to form full memory blocks.
  return tf.concat(pieces, axis=-2)
def break_into_blocks_nd(x, block_shape):
  """Break input tensor into blocks of `block_shape`.

  Args:
    x: a [batch, d1, d2, ..., dn, depth] tensor
    block_shape: a n-d list of integers representing block shape

  Returns:
    a [batch, d1//block1, ..., dn//blockn, block1 *... * blockn, depth] tensor
  """
  shape = common_layers.shape_list(x)
  assert all([length % size == 0
              for length, size in zip(shape[1:], block_shape)])
  rank = len(block_shape)
  counts = [length // size for length, size in zip(shape[1:], block_shape)]
  # Reshape to [batch, c1, b1, c2, b2, ..., cn, bn, depth].
  interleaved = []
  for count, size in zip(counts, block_shape):
    interleaved.extend([count, size])
  x = tf.reshape(x, [-1] + interleaved + shape[-1:])
  # Move all count axes (odd positions) in front of all block axes (even
  # positions), keeping batch first and depth last.
  count_axes = [2 * i + 1 for i in range(rank)]
  size_axes = [2 * i + 2 for i in range(rank)]
  x = tf.transpose(x, [0] + count_axes + size_axes + [2 * rank + 1])
  items_per_block = np.prod(block_shape, dtype=np.int32)
  return tf.reshape(x, [-1] + counts + [items_per_block] + shape[-1:])
def put_back_blocks_nd(x, block_shape):
  """Restructure input tensor from blocks to normal ordering.

  Inverse of `break_into_blocks_nd`.

  Args:
    x: a [batch, b1, ..., bn, items_in_block, depth] tensor
    block_shape: a n-d list of integers representing block shape.

  Returns:
    a [batch, d1, ..., dn, depth] where blocks are put back to form the
    original tensor.
  """
  shape = common_layers.shape_list(x)
  assert shape[-2] == np.prod(block_shape)
  rank = len(block_shape)
  # Expand items_in_block into its n block dimensions:
  # [batch, c1, ..., cn, q1, ..., qn, depth].
  x = tf.reshape(x, shape[:-2] + list(block_shape) + shape[-1:])
  # Interleave to [batch, c1, q1, c2, q2, ..., cn, qn, depth].
  permutation = [0]
  for axis in range(1, rank + 1):
    permutation += [axis, axis + rank]
  permutation.append(2 * rank + 1)
  x = tf.transpose(x, permutation)
  shape = common_layers.shape_list(x)
  # Merge each (count, block size) pair back into one spatial dimension.
  spatial = [shape[2 * i + 1] * shape[2 * i + 2] for i in range(rank)]
  return tf.reshape(x, [-1] + spatial + shape[-1:])
def pad_to_multiple_nd(x, block_shape):
  """Pads the end of each spatial dimension up to a multiple of the block size.

  Args:
    x: a [batch, d1, d2, ..., dn, depth] tensor
    block_shape: a n-d list of integers representing block shape

  Returns:
    padded x where each dimension is a multiple of corresponding block length.
  """
  spatial_dims = common_layers.shape_list(x)[1:-1]
  # (-length) % size is the smallest non-negative pad that makes length
  # divisible by size.
  pad_amounts = [(-length) % size
                 for length, size in zip(spatial_dims, block_shape)]
  paddings = [[0, 0]] + [[0, amount] for amount in pad_amounts] + [[0, 0]]
  return tf.pad(x, paddings)
def causal_attention_bias_nd(query_shape, memory_flange, decode_step=None):
  """Creates causal attention bias for local nd attention.

  The memory is laid out as in `break_into_memory_blocks_nd` with
  masked=True. Blocks before the center block are fully visible, the center
  block gets a causal (lower-triangular) mask, and blocks after it are fully
  masked with -1e9. This assumes memory_flange is divisible by query_shape in
  every dimension.

  Args:
    query_shape: a n-d list of integers representing query shape
    memory_flange: a n-d list of integers representing memory flange
    decode_step: an integer

  Returns:
    a [1, 1, query_items, memory_items] tensor for masked attention bias or
    a [1, 1, 1, memory_items] tensor if decode_step is not None.
  """
  assert all([m % q == 0 for q, m in zip(query_shape, memory_flange)])
  flange_blocks = [m // q for q, m in zip(query_shape, memory_flange)]
  full_extent = [2 * b + 1 for b in flange_blocks]
  masked_extent = [flange_blocks[0] + 1] + full_extent[1:]
  # With blocks on both sides of the center in every dimension, the blocks
  # preceding the center in raster order are exactly half of them.
  prev_blocks = np.prod(full_extent, dtype=np.int32) // 2
  all_blocks = np.prod(masked_extent, dtype=np.int32)
  future_blocks = all_blocks - prev_blocks - 1

  items_in_block = np.prod(query_shape, dtype=np.int32)
  items_in_query = 1 if decode_step is not None else items_in_block

  prev_bias = tf.zeros([1, 1, items_in_query, prev_blocks * items_in_block])
  if decode_step is not None:
    # A single query position: it sees center-block items up to and including
    # its own offset inside the block.
    step_in_block = decode_step % items_in_block
    visible = tf.reshape(
        tf.less_equal(tf.range(items_in_block, dtype=tf.int32), step_in_block),
        [1, 1, items_in_query, items_in_block])
    center_bias = tf.where(
        visible, tf.zeros([1, 1, items_in_query, items_in_block]),
        -1e9 * tf.ones([1, 1, items_in_query, items_in_block]))
  else:
    center_bias = attention_bias_lower_triangle(items_in_block)
  future_bias = -1e9 * tf.ones(
      [1, 1, items_in_query, future_blocks * items_in_block])
  return tf.concat([prev_bias, center_bias, future_bias], axis=3)
def compute_attention_component(antecedent,
                                total_depth,
                                filter_width=1,
                                padding="VALID",
                                name="c",
                                vars_3d_num_heads=0,
                                layer_collection=None):
  """Computes attention component (query, key or value).

  Three projection paths are possible:
    - a 3d [input_depth, heads, depth_per_head] variable, when
      vars_3d_num_heads > 0;
    - a bias-free dense layer, when filter_width == 1;
    - a 1d convolution otherwise.

  Args:
    antecedent: a Tensor with shape [batch, length, channels]
    total_depth: an integer
    filter_width: An integer specifying how wide you want the attention
      component to be.
    padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    name: a string specifying scope name.
    vars_3d_num_heads: an optional integer (if we want to use 3d variables)
    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
      KFAC optimizer. Default is None.

  Returns:
    c : [batch, length, depth] tensor

  Raises:
    ValueError: if layer_collection is given together with filter_width != 1
      or vars_3d_num_heads != 0.
  """
  if layer_collection is not None and (filter_width != 1 or
                                       vars_3d_num_heads != 0):
    raise ValueError(
        "KFAC implementation only supports filter_width=1 (actual: {}) and "
        "vars_3d_num_heads=0 (actual: {}).".format(
            filter_width, vars_3d_num_heads))

  use_3d_vars = vars_3d_num_heads is not None and vars_3d_num_heads > 0
  if not use_3d_vars:
    if filter_width != 1:
      return common_layers.conv1d(
          antecedent, total_depth, filter_width, padding=padding, name=name)
    return common_layers.dense(
        antecedent, total_depth, use_bias=False, name=name,
        layer_collection=layer_collection)

  assert filter_width == 1
  input_depth = antecedent.get_shape().as_list()[-1]
  depth_per_head = total_depth // vars_3d_num_heads
  stddev = input_depth ** -0.5
  # Query kernels fold in the 1/sqrt(depth_per_head) attention scaling here.
  if "q" in name:
    stddev *= depth_per_head ** -0.5
  kernel = tf.get_variable(
      name, [input_depth, vars_3d_num_heads, depth_per_head],
      initializer=tf.random_normal_initializer(stddev=stddev))
  kernel = tf.reshape(tf.cast(kernel, antecedent.dtype),
                      [input_depth, total_depth])
  return tf.tensordot(antecedent, kernel, axes=1)
def compute_qkv(query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width=1,
                kv_filter_width=1,
                q_padding="VALID",
                kv_padding="VALID",
                vars_3d_num_heads=0,
                layer_collection=None):
  """Computes query, key and value.

  Keys and values are projected from memory_antecedent, or from
  query_antecedent when memory_antecedent is None (self-attention).

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
      to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    vars_3d_num_heads: an optional (if we want to use 3d variables)
    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
      KFAC optimizer. Default is None.

  Returns:
    q, k, v : [batch, length, depth] tensors
  """
  memory = (query_antecedent if memory_antecedent is None
            else memory_antecedent)

  def _project(antecedent, depth, width, pad, scope):
    # Shared projection; only the input, depth, filter and scope differ.
    return compute_attention_component(
        antecedent,
        depth,
        width,
        pad,
        scope,
        vars_3d_num_heads=vars_3d_num_heads,
        layer_collection=layer_collection)

  q = _project(query_antecedent, total_key_depth, q_filter_width, q_padding,
               "q")
  k = _project(memory, total_key_depth, kv_filter_width, kv_padding, "k")
  v = _project(memory, total_value_depth, kv_filter_width, kv_padding, "v")
  return q, k, v
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        max_relative_position=None,
                        heads_share_relative_embedding=False,
                        add_relative_to_values=False,
                        image_shapes=None,
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        layer_collection=None,
                        recurrent_memory=None,
                        chunk_number=None,
                        hard_attention_k=0,
                        gumbel_noise_weight=0.0,
                        max_area_width=1,
                        max_area_height=1,
                        memory_height=1,
                        area_key_mode="mean",
                        area_value_mode="sum",
                        training=True,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  The function works in four steps:
    1. Project the inputs to q, k, v (compute_qkv).
    2. Split them into heads and scale q by depth_per_head**-0.5 (unless
       vars_3d is set).
    3. Dispatch to the attention function selected by `attention_type`.
    4. Combine the heads and apply the output projection.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    heads_share_relative_embedding: boolean to share relative embeddings
    add_relative_to_values: a boolean for whether to add relative component to
                            values.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
                no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
           For encoder-decoder attention (memory_antecedent not None) the
           dict must instead provide 'k_encdec' and 'v_encdec'.
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
      KFAC optimizer. Default is None.
    recurrent_memory: An optional transformer_memory.RecurrentMemory, which
      retains state across chunks. Default is None.
    chunk_number: an optional integer Tensor with shape [batch] used to operate
      the recurrent_memory.
    hard_attention_k: integer, if > 0 triggers hard attention (picking top-k).
    gumbel_noise_weight: if > 0, apply Gumbel noise with weight
      `gumbel_noise_weight` before picking top-k. This is a no op if
      hard_attention_k <= 0.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "concat", "sum", "sample_concat", and "sample_sum".
    area_value_mode: the mode for computing area values, which can be either
      "mean", or "sum".
    training: indicating if it is in the training mode.
    **kwargs (dict): Parameters for the attention function. The key
      "decode_loop_step" (and "activation_dtype" for "dot_product") is also
      read directly by this function.

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns an additional loss parameters (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads, or if an unsupported combination of
      layer_collection / cache / vars_3d / recurrent_memory is requested.
    NotImplementedError: if cache is given with an attention_type other than
      "dot_product" or "dot_product_relative".
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  # compute_attention_component treats 0 as "use ordinary 2d projections".
  vars_3d_num_heads = num_heads if vars_3d else 0

  if layer_collection is not None:
    if cache is not None:
      raise ValueError("KFAC implementation only supports cache is None.")
    if vars_3d:
      raise ValueError("KFAC implementation does not support 3d vars.")

  if recurrent_memory is not None:
    if memory_antecedent is not None:
      raise ValueError("Recurrent memory requires memory_antecedent is None.")
    if cache is not None:
      raise ValueError("Cache is not supported when using recurrent memory.")
    if vars_3d:
      raise ValueError("3d vars are not supported when using recurrent memory.")
    if layer_collection is not None:
      raise ValueError("KFAC is not supported when using recurrent memory.")
    if chunk_number is None:
      raise ValueError("chunk_number is required when using recurrent memory.")

  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):

    if recurrent_memory is not None:
      # The recurrent memory may rewrite the antecedents and bias; the
      # returned transaction is consumed by post_attention at the end.
      (
          recurrent_memory_transaction,
          query_antecedent, memory_antecedent, bias,
      ) = recurrent_memory.pre_attention(
          chunk_number,
          query_antecedent, memory_antecedent, bias,
      )

    # With an encoder-decoder cache, k and v come from the cache instead, so
    # only the self-attention / uncached paths project all three here.
    if cache is None or memory_antecedent is None:
      q, k, v = compute_qkv(query_antecedent, memory_antecedent,
                            total_key_depth, total_value_depth, q_filter_width,
                            kv_filter_width, q_padding, kv_padding,
                            vars_3d_num_heads=vars_3d_num_heads,
                            layer_collection=layer_collection)
    if cache is not None:
      # NOTE: the message below mentions only dot_product, but
      # "dot_product_relative" is also accepted by this check.
      if attention_type not in ["dot_product", "dot_product_relative"]:
        # TODO(petershaw): Support caching when using relative position
        # representations, i.e. "dot_product_relative" attention.
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other than"
            " dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = compute_attention_component(query_antecedent, total_key_depth,
                                        q_filter_width, q_padding, "q",
                                        vars_3d_num_heads=vars_3d_num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        # Decoder self-attention: split the new step into heads and append
        # it to the cached keys/values along the length axis (axis 2).
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of current implementation is better than updating
          # the tensor by adding the result of matmul(one_hot,
          # update_in_current_step)
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = split_heads(q, num_heads)
    # On the cached paths k and v were already split above (or come from the
    # cache), so they are split here only when there is no cache.
    if cache is None:
      k = split_heads(k, num_heads)
      v = split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    # With vars_3d the q projection already folds this scaling into its
    # initializer (see compute_attention_component), so it is skipped here.
    if not vars_3d:
      q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      if max_area_width > 1 or max_area_height > 1:
        x = area_attention.dot_product_area_attention(
            q, k, v, bias, dropout_rate, image_shapes,
            save_weights_to=save_weights_to,
            dropout_broadcast_dims=dropout_broadcast_dims,
            max_area_width=max_area_width,
            max_area_height=max_area_height,
            memory_height=memory_height,
            area_key_mode=area_key_mode,
            area_value_mode=area_value_mode,
            training=training)
      else:
        x = dot_product_attention(
            q, k, v, bias, dropout_rate, image_shapes,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            dropout_broadcast_dims=dropout_broadcast_dims,
            activation_dtype=kwargs.get("activation_dtype"),
            hard_attention_k=hard_attention_k,
            gumbel_noise_weight=gumbel_noise_weight)
    elif attention_type == "dot_product_relative":
      x = dot_product_attention_relative(
          q,
          k,
          v,
          bias,
          max_relative_position,
          dropout_rate,
          image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          cache=cache is not None,
          allow_memory=recurrent_memory is not None,
          hard_attention_k=hard_attention_k,
          gumbel_noise_weight=gumbel_noise_weight)
    elif attention_type == "dot_product_unmasked_relative_v2":
      x = dot_product_unmasked_self_attention_relative_v2(
          q,
          k,
          v,
          bias,
          max_relative_position,
          dropout_rate,
          image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims,
          heads_share_relative_embedding=heads_share_relative_embedding,
          add_relative_to_values=add_relative_to_values)
    elif attention_type == "dot_product_relative_v2":
      x = dot_product_self_attention_relative_v2(
          q,
          k,
          v,
          bias,
          max_relative_position,
          dropout_rate,
          image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims,
          heads_share_relative_embedding=heads_share_relative_embedding,
          add_relative_to_values=add_relative_to_values)
    elif attention_type == "local_within_block_mask_right":
      x = masked_within_block_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "local_relative_mask_right":
      x = masked_relative_local_attention_1d(
          q,
          k,
          v,
          block_length=block_length,
          make_image_summary=make_image_summary,
          dropout_rate=dropout_rate,
          heads_share_relative_embedding=heads_share_relative_embedding,
          add_relative_to_values=add_relative_to_values,
          name="masked_relative_local_attention_1d")
    elif attention_type == "local_mask_right":
      x = masked_local_attention_1d(
          q,
          k,
          v,
          block_length=block_length,
          make_image_summary=make_image_summary)
    elif attention_type == "local_unmasked":
      x = local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = masked_dilated_self_attention_1d(q, k, v, block_length, block_width,
                                           gap_size, num_memory_blocks)
    else:
      # Any unrecognised string falls through to this assert.
      assert attention_type == "unmasked_dilated_1d"
      x = dilated_self_attention_1d(q, k, v, block_length, block_width,
                                    gap_size, num_memory_blocks)
    x = combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.cast(o_var, x.dtype)
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform",
          layer_collection=layer_collection)

    if recurrent_memory is not None:
      x = recurrent_memory.post_attention(recurrent_memory_transaction, x)

    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
def multihead_attention_2d(query_antecedent,
                           memory_antecedent,
                           total_key_depth,
                           total_value_depth,
                           output_depth,
                           num_heads,
                           attention_type="local_attention_2d",
                           query_shape=(8, 16),
                           memory_flange=(8, 16),
                           name=None):
  """2d Multihead scaled-dot-product attention with inp/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, h, w, depth_k]
    memory_antecedent: a Tensor with shape [batch, h, w, depth_k]
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    attention_type: String, type of attention function to use. One of
      "local_attention_2d", "masked_local_attention_2d" or
      "unmasked_local_attention_2d_tpu".
    query_shape: a tuple indicating the height and width of each query block.
    memory_flange: a tuple indicating how much to look in height and width
      around each query block (not used by "unmasked_local_attention_2d_tpu").
    name: an optional string

  Returns:
    A Tensor of shape [batch, h, w, output_depth]

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads, or if `attention_type` is not supported.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  supported_types = ("local_attention_2d", "masked_local_attention_2d",
                     "unmasked_local_attention_2d_tpu")
  # Validate before any variables are created so a typo fails cleanly.
  if attention_type not in supported_types:
    raise ValueError("Unsupported attention_type %r; expected one of %s." %
                     (attention_type, supported_types))
  with tf.variable_scope(
      name,
      default_name="multihead_attention_2d",
      values=[query_antecedent, memory_antecedent]):
    q, k, v = compute_qkv(query_antecedent, memory_antecedent, total_key_depth,
                          total_value_depth)
    # after splitting, shape is [batch, heads, h, w, depth]
    q = split_heads_2d(q, num_heads)
    k = split_heads_2d(k, num_heads)
    v = split_heads_2d(v, num_heads)
    key_depth_per_head = total_key_depth // num_heads
    q *= key_depth_per_head**-0.5
    if attention_type == "local_attention_2d":
      x = local_attention_2d(
          q, k, v, query_shape=query_shape, memory_flange=memory_flange)
    elif attention_type == "masked_local_attention_2d":
      x = masked_local_attention_2d(
          q, k, v, query_shape=query_shape, memory_flange=memory_flange)
    else:  # "unmasked_local_attention_2d_tpu"
      x = dot_product_unmasked_attention_local_2d_tpu(
          q, k, v, None, max_relative_position=None, query_shape=query_shape)
    x = combine_heads_2d(x)
    x = common_layers.dense(
        x, output_depth, use_bias=False, name="output_transform")
    return x
def multihead_attention_nd(query_antecedent,
                           memory_antecedent,
                           total_key_depth,
                           total_value_depth,
                           output_depth,
                           num_heads,
                           query_shape,
                           memory_flange,
                           masked=False,
                           cache=None,
                           decode_step=None,
                           name=None):
  """n-d Multihead scaled-dot-product attention with in/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, d1, ..., dn, depth_q] or
      [batch, 1, ..., 1, depth_q] if in fast decoding mode.
    memory_antecedent: a Tensor with shape [batch, d1, ..., dn, depth_m] or None
      for self attention.
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    query_shape: a tuple indicating the dimensions of each query block.
    memory_flange: a tuple indicating how much to look around a query block
      in each dimension
    masked: a boolean to specify whether to do masked or unmasked attention.
      Only masked attention is currently implemented.
    cache: a dict like: {
      'k': [batch, num_heads, d1, ..., dn, depth_k // num_heads],
      'v': [batch, num_heads, d1, ..., dn, depth_v // num_heads]} Caller should
      initially pass zero tensors for `decode_step` == 0. This method will
      update cache and caller should pass the same cache in consecutive calls.
      This works for both GPU and TPU inference. Caller should pass the latest
      query via `query_antecedent`. `memory_antecedent` should be None in this
      case, since auto-regressive decoding only applies to self attention.
    decode_step: integer to pass in decoding mode. `cache` and `decode_step`
      should both be set in decoding mode. Caller can also pass an empty `cache`
      without `decode_step`, for this method to initialize the cache for future
      calls with `decode_step` > 0.
    name: an optional string

  Returns:
    A Tensor of shape [batch, d1, ..., dn, output_depth] or
    [batch, 1, ..., 1, output_depth] if decode_step is set.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads, if `decode_step` is set without a cache
      holding 'k' and 'v', or if `memory_antecedent` is given with a cache.
    NotImplementedError: if `masked` is False.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  # Validate decoding input params are sensible. Explicit checks instead of
  # asserts: they survive `python -O`, and a missing cache no longer surfaces
  # as a TypeError from `"k" in None`.
  if decode_step is not None and (cache is None or "k" not in cache or
                                  "v" not in cache):
    raise ValueError(
        "`cache` with 'k' and 'v' entries is required when `decode_step` is "
        "set.")
  if cache is not None and memory_antecedent is not None:
    raise ValueError(
        "`memory_antecedent` must be None when `cache` is given, since "
        "auto-regressive decoding only applies to self attention.")
  # Fail before any variables are created.
  if not masked:
    raise NotImplementedError(
        "Unmasked multihead attention nd is not implemented")
  with tf.variable_scope(
      name,
      default_name="multihead_attention_nd",
      values=[query_antecedent, memory_antecedent]):
    if decode_step is not None:
      latest_antecedent = query_antecedent
      q, latest_k, latest_v = compute_qkv(latest_antecedent, None,
                                          total_key_depth, total_value_depth)
      latest_k = split_heads_nd(latest_k, num_heads)
      latest_v = split_heads_nd(latest_v, num_heads)
      # put latest k and v into their correct position in cache.
      k = cache["k"]
      v = cache["v"]
      k = put_item_in_decode_step(k, latest_k, decode_step, query_shape)
      v = put_item_in_decode_step(v, latest_v, decode_step, query_shape)
      cache["k"] = k
      cache["v"] = v
    else:
      q, k, v = compute_qkv(query_antecedent, memory_antecedent,
                            total_key_depth, total_value_depth)
      k = split_heads_nd(k, num_heads)
      v = split_heads_nd(v, num_heads)
      if cache is not None:
        # Seed the cache for subsequent decode steps.
        cache["k"] = k
        cache["v"] = v
    # after splitting, shape is [batch, heads, d1, ..., dn, depth]
    q = split_heads_nd(q, num_heads)
    key_depth_per_head = total_key_depth // num_heads
    q *= key_depth_per_head**-0.5
    x = masked_local_attention_nd(
        q,
        k,
        v,
        query_shape=query_shape,
        memory_flange=memory_flange,
        decode_step=decode_step)
    x = combine_heads_nd(x)
    x = common_layers.dense(
        x, output_depth, use_bias=False, name="output_transform")
    return x
def decode_step_to_index(decode_step, query_shape, tensor_shape):
  """Converts a decode step into an n-d index under blocked raster-scan order.

  Positions are visited block by block (blocks of size `query_shape`, in
  row-major order), and row-major within each block.

  Args:
    decode_step: an integer
    query_shape: a tuple (q1, q2, ..., qn) representing the query shape
    tensor_shape: a tuple (d1, d2, ..., dn) representing the tensor shape, minus
      the batch and depth dimensions.

  Returns:
    a tuple (i1, i2, ..., in) with the index of the element visited at
    `decode_step`.
  """
  assert len(query_shape) == len(tensor_shape)
  num_blocks = [dim // q for dim, q in zip(tensor_shape, query_shape)]
  block_size = np.prod(query_shape, dtype=np.int32)
  block_step = decode_step // block_size
  inner_step = decode_step % block_size

  def _unravel(flat, sizes):
    """Row-major unravel of `flat` over `sizes` (last dimension fastest)."""
    coords = []
    for size in reversed(sizes):
      coords.append(flat % size)
      flat //= size
    coords.reverse()
    return coords

  block_coords = _unravel(block_step, num_blocks)
  inner_coords = _unravel(inner_step, query_shape)
  return tuple(inner + block * q
               for inner, block, q in zip(inner_coords, block_coords,
                                          query_shape))
def get_item_at_decode_step(x, decode_step, query_shape):
  """Slices out the element of an n-d tensor visited at `decode_step`.

  Args:
    x: a [batch, d1, d2, ..., dn, depth] tensor
    decode_step: an integer
    query_shape: a tuple (q1, q2, ..., qn) representing the query shape

  Returns:
    a [batch, 1, 1, ..., 1, depth] tensor holding the element of `x` at
    `decode_step` in blocked raster-scan order.
  """
  full_shape = common_layers.shape_list(x)
  batch, depth = full_shape[0], full_shape[-1]
  position = list(
      decode_step_to_index(decode_step, query_shape, full_shape[1:-1]))
  # TPU needs size to be non negative for the case when begins are not
  # compile-time constants.
  begin = [0] + position + [0]
  size = [batch] + [1] * len(position) + [depth]
  return tf.slice(x, begin, size)
def put_item_in_decode_step(x, item, decode_step, query_shape):
  """Writes `item` into an n-d tensor at the position visited at `decode_step`.

  Args:
    x: a [batch, heads, d1, d2, ..., dn, depth] tensor
    item: a [batch, heads, 1, 1, ..., 1, depth] tensor
    decode_step: an integer
    query_shape: a tuple (q1, q2, ..., qn) representing the query shape

  Returns:
    a [batch, heads, d1, d2, ..., dn, depth] tensor equal to `x` except that
    the element at `decode_step` (blocked raster-scan order) is `item`.
  """
  x_shape = common_layers.shape_list(x)
  num_heads, depth = x_shape[1], x_shape[-1]
  spatial_shape = x_shape[2:-1]
  index = decode_step_to_index(decode_step, query_shape, spatial_shape)
  # alias_inplace_update only writes along axis 0, so collapse the spatial dims
  # into a single positions axis and move it to the front:
  # [batch, heads, positions, depth] -> [positions, batch, heads, depth].
  positions_first = tf.transpose(
      tf.reshape(x, [-1, num_heads, np.prod(spatial_shape), depth]),
      [2, 0, 1, 3])
  # Row-major offset of `index` inside the flattened positions axis.
  offset = 0
  stride = 1
  for dim_size, coord in zip(reversed(spatial_shape), reversed(index)):
    offset += coord * stride
    stride *= dim_size
  item_dims = common_layers.shape_list(item)
  item = tf.reshape(item, item_dims[:2] + item_dims[-1:])
  updated = inplace_ops.alias_inplace_update(positions_first, offset, item)
  # Move positions back behind heads and restore the original layout.
  updated = tf.transpose(updated, [1, 2, 0, 3])
  return tf.reshape(updated, [-1, num_heads] + spatial_shape + [depth])
def ffn_self_attention_layer(x,
                             filter_depth,
                             output_depth,
                             num_parts,
                             dropout_rate,
                             share_kv=False,
                             name=None):
  """Self-attention feedforward layer.

  We use self-attention to do feedforward computations. We apply this function
  positionwise where for each position, we linearly transform the output to have
  depth filter_depth, and break up the result depth-wise into num_parts
  contiguous parts. The parts self-attend, we concatenate the results
  depth-wise, and we linearly transform to a depth of output_depth. The goal is
  to get multiplicative interactions between components of a representation.

  Args:
    x: a Tensor with shape [batch, length, channels]
    filter_depth: an integer
    output_depth: an integer
    num_parts: an integer dividing filter depth
    dropout_rate: a floating point number
    share_kv: Share the key value transform
    name: an optional string

  Returns:
    A Tensor with shape [batch, length, output_depth].

  Raises:
    ValueError: if `num_parts` does not divide `filter_depth`.
  """
  if filter_depth % num_parts != 0:
    raise ValueError("filter_depth (%d) must be divisible by num_parts (%d)." %
                     (filter_depth, num_parts))
  with tf.variable_scope(
      name, default_name="feedforward_self_attention", values=[x]):
    x_shape = common_layers.shape_list(x)
    part_depth = filter_depth // num_parts
    if not share_kv:
      combined = common_layers.dense(
          x, filter_depth * 3, use_bias=False, name="qkv_transform")
      combined = tf.expand_dims(combined, axis=2)
      q, k, v = tf.split(combined, 3, axis=3)
    else:
      q = tf.expand_dims(
          common_layers.dense(
              x, filter_depth, use_bias=False, name="q_transform"),
          axis=2)
      # Keys and values share a single projection, so they are the same
      # tensor. Projecting x once uses the same "kv_transform" variables as
      # projecting concat([x, x]) and splitting it into two identical halves,
      # at half the cost.
      k = v = tf.expand_dims(
          common_layers.dense(
              x, filter_depth, use_bias=False, name="kv_transform"),
          axis=2)
    # Each position becomes its own "batch" of num_parts items that attend to
    # each other.
    batch_q = tf.reshape(q, [-1, 1, num_parts, part_depth])
    batch_k = tf.reshape(k, [-1, 1, num_parts, part_depth])
    batch_v = tf.reshape(v, [-1, 1, num_parts, part_depth])
    batch_q *= part_depth**-0.5
    # non-masked bias
    bias = None
    x = dot_product_attention(batch_q, batch_k, batch_v, bias, dropout_rate)
    x = tf.reshape(x, [x_shape[0], x_shape[1], filter_depth])
    x = common_layers.dense(
        x, output_depth, use_bias=False, name="output_transform")
    return x
def parameter_attention(x,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        memory_rows,
                        num_heads,
                        dropout_rate,
                        name=None):
  """Attention over parameters.

  We use the same multi-headed attention as in the other layers, but the memory
  keys and values are model parameters. There are no linear transformation on
  the keys or values.

  We are also a bit more careful about memory usage, since the number of
  memory positions may be very large.

  Args:
    x: a Tensor with shape [batch, length_q, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    memory_rows: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    name: an optional string

  Returns:
    A Tensor with shape [batch, length_q, output_depth].

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  # Same checks as the other multihead functions; otherwise the reshapes below
  # fail with a much less readable error.
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(name, default_name="parameter_attention", values=[x]):
    head_size_k = total_key_depth // num_heads
    head_size_v = total_value_depth // num_heads
    var_shape_k = [num_heads, memory_rows, head_size_k]
    var_shape_v = [num_heads, memory_rows, head_size_v]
    k = tf.get_variable(
        "k",
        var_shape_k,
        initializer=tf.random_normal_initializer(
            0, output_depth**-0.5 * (num_heads**0.5)))
    # NOTE(review): this stddev simplifies to 1.0.
    v = tf.get_variable(
        "v",
        var_shape_v,
        initializer=tf.random_normal_initializer(
            0, output_depth**-0.5 * (output_depth**0.5)))
    batch_size = common_layers.shape_list(x)[0]
    length = common_layers.shape_list(x)[1]
    q = common_layers.dense(
        x, total_key_depth, use_bias=False, name="q_transform")
    if dropout_rate:
      # This is a cheaper form of attention dropout where we use the same
      # dropout decisions across batch elements and query positions,
      # but different decisions across heads and memory positions.
      v = tf.nn.dropout(
          v, 1.0 - dropout_rate, noise_shape=[num_heads, memory_rows, 1])
    # query is [batch, length, hidden_size]
    # reshape and transpose it to [heads, batch * length, head_size]
    q = tf.reshape(q, [batch_size, length, num_heads, head_size_k])
    q = tf.transpose(q, [2, 0, 1, 3])
    q = tf.reshape(q, [num_heads, batch_size * length, head_size_k])
    # NOTE(review): unlike multihead_attention_2d/nd, q is not multiplied by
    # head_size_k**-0.5 here.
    weights = tf.matmul(q, k, transpose_b=True)
    weights = tf.nn.softmax(weights)
    y = tf.matmul(weights, v)
    y = tf.reshape(y, [num_heads, batch_size, length, head_size_v])
    y = tf.transpose(y, [1, 2, 0, 3])
    y = tf.reshape(y, [batch_size, length, total_value_depth])
    y.set_shape([None, None, total_value_depth])
    y = common_layers.dense(
        y, output_depth, use_bias=False, name="output_transform")
    return y
@expert_utils.add_name_scope()
def coordinate_tensor(shape, axis):
  """Builds an int32 tensor whose entries equal their own index along `axis`.

  Args:
    shape: a Tensor representing the shape of the output Tensor
    axis: an integer; negative values count back from the last dimension.

  Returns:
    A tf.int32 tensor of shape `shape` in which every element holds its
    coordinate along `axis`.
  """
  rank = tf.size(shape)
  if axis < 0:
    # Convert to a non-negative axis so it can be used as a one_hot index.
    axis += rank
  positions = tf.range(shape[axis])
  # -1 at `axis` and 1 everywhere else, so the reshaped range broadcasts over
  # every other dimension.
  broadcast_shape = tf.one_hot(
      axis, rank, on_value=-1, off_value=1, dtype=tf.int32)
  return tf.reshape(positions, broadcast_shape) + tf.zeros(
      shape, dtype=tf.int32)
def self_attention_expert(x,
                          batch_coordinate,
                          mask_right=True,
                          split_batch=False,
                          attention_num_head=1,
                          attention_kq_size=None,
                          attention_v_size=None):
  """Implementing attention that runs inside each expert.

  Args:
    x: A tensor of shape[batch, depth]. Contains representations from
      different positions, which are lexicographically ordered.
    batch_coordinate: A tensor of shape [batch, 1] containing the batch
      coordinate of each element in x. This is needed to make sure that
      positions from different sequences don't attend to each other.
    mask_right: A bool. If true, we will not attend to positions on the right,
      just as decoder self attention.
    split_batch (bool): If True, each sequence of the batch is processed
      individually on a loop. If False, the sequences are processed all at
      once and a mask is applied to isolate the sequences from each others
    attention_num_head (int): number of attention heads
    attention_kq_size (int): dimension used for the attention key, and query
    attention_v_size (int): dimension used for the attention value

  Returns:
    out: A tensor of shape [batch, depth].

  example use:
    expert_utils.local_moe(
        ...
        expert_fn=functools.partial(self_attention_expert, mask_right=True)
    )
  """

  depth = x.get_shape().as_list()[-1]
  length = common_layers.shape_list(batch_coordinate)[0]

  # Print a warning message if one of the expert isn't used (useful at
  # inference where summaries aren't used and the gating function don't add
  # noise)
  global _expert_count  # Hack to make each expert have a unique id
  _expert_count += 1
  length = tf.cond(
      tf.equal(length, 0),
      lambda: tf.Print(  # pylint: disable=g-long-lambda
          length, [length], "Expert {} empty: ".format(_expert_count)),
      lambda: length,
  )

  tf.summary.scalar("batch_size", length, family="experts_stats_batch_size")

  attention_kq_size = attention_kq_size or depth
  attention_v_size = attention_v_size or depth

  def length_not_null(x, batch_coordinate):
    """Branch of the graph only evaluated when length isn't null."""

    # Mask between the sequences (not used if map_ids is used)
    bias_batch = attention_bias_coordinates(batch_coordinate)

    def add_or_set_if(prev_bias, new_bias, condition):
      """Add the bias together while considering the None case."""
      if not condition:
        return prev_bias
      if prev_bias is None:
        return new_bias
      return prev_bias + new_bias

    def mask_and_call_attention(x):
      """Function applied once for each sequence of the batch."""

      # Mask to prevent sequences of attending to the future
      length = common_layers.shape_list(x)[1]  # x has shape [1, length,...]
      bias_past = tf.reshape(
          attention_bias_lower_triangle(length), [length, length])
      # bias has shape [length, length]

      bias = None
      bias = add_or_set_if(bias, bias_past, mask_right)
      bias = add_or_set_if(bias, bias_batch, not split_batch)
      # With mask_right=False and split_batch=True neither mask applies and
      # bias stays None; reshaping None would fail, so only reshape a real
      # bias and otherwise attend without one.
      if bias is not None:
        bias = tf.reshape(bias, [1, 1, length, length])

      return multihead_attention(
          x,
          None,
          bias,
          total_key_depth=attention_kq_size,
          total_value_depth=attention_v_size,
          output_depth=depth,
          num_heads=attention_num_head,
          dropout_rate=0.0)

    if split_batch:
      out = expert_utils.map_ids(x, batch_coordinate, mask_and_call_attention)
    else:
      x = tf.reshape(x, [1, length, depth])
      out = mask_and_call_attention(x)
      out = tf.squeeze(out, 0)
    return out

  # If the length is empty, just forward an empty tensor (avoid having to
  # evaluate multihead_attention with tensor having dim equal to zeros)
  out = tf.cond(
      tf.equal(length, 0),
      lambda: tf.zeros(shape=[0, depth], dtype=tf.float32, name="empty_out"),
      lambda: length_not_null(x, batch_coordinate),
  )
  return out
def local_expert_attention(x,
                           k,
                           loss_coef,
                           attention_num_experts,
                           train=True,
                           batch_coordinate=None,
                           **kwargs):
  """Attention using a mixture of experts.

  Positions routed to the same expert can attend to each other. The mixture of
  experts is "local": it is replicated on each datashard.

  local_moe flattens all batches, so padding should be removed beforehand to
  avoid problems (e.g. all padding going to the same expert, or self attention
  attending to non-null padding tokens).

  Args:
    x: a Tensor with shape [batch, length, depth] or [1, batch*length, depth]
    k: The number of experts to dispatch each example to
    loss_coef: a scalar. A multiplier for the expert loss
    attention_num_experts: The number of experts to use
    train: a boolean for the current mode
    batch_coordinate (tf.Tensor): int32 tensor of shape [1, batch*length, 1]
      containing the batch ids. If None, deduced from first dim of x.
    **kwargs: Arguments to forward to self_attention_expert

  Returns:
    y: a Tensor with shape [batch, length, depth]
    loss: a Scalar
  """
  if batch_coordinate is None:
    # Tag every position with its index along the first dimension of x.
    leading_shape = common_layers.shape_list(x)[:-1]
    batch_coordinate = tf.expand_dims(
        coordinate_tensor(leading_shape, axis=0), axis=-1)
  expert_fn = functools.partial(self_attention_expert, **kwargs)
  with tf.variable_scope("local_expert_attention"):
    return expert_utils.local_moe(
        x,
        train,
        expert_fn,
        attention_num_experts,
        k=k,
        loss_coef=loss_coef,
        pass_x=True,
        pass_gates=False,
        additional_dispatch_params={"batch_coordinate": batch_coordinate},
    )
@expert_utils.add_name_scope()
def expert_dot_product(q, k, v, info_q, info_k):
  """Dot-product attention restricted to the positions given to one bucket.

  A mask keeps positions of different sequences apart and, when `info_k.order`
  is set, prevents attending to the future.

  Args:
    q (tf.Tensor): Queries of shape [length_expert_q, depth_k]
    k (tf.Tensor): Keys of shape [length_expert_k, depth_k]
    v (tf.Tensor): Values of shape [length_expert_k, depth_v]
    info_q (BatchInfo): Batch info for queries; its `coordinates` are always
      used to build the mask.
    info_k (BatchInfo): Batch info for keys

  Returns:
    tf.Tensor: dot product attention output ([length_expert_q, depth_v])
  """
  length_q = common_layers.shape_list(q)[0]
  length_k = common_layers.shape_list(k)[0]
  depth_v = v.get_shape().as_list()[-1]

  # Build the mask.
  bias = attention_bias_coordinates(info_q.coordinates, info_k.coordinates)
  if info_k.order is not None:
    bias += attention_bias_future(info_q.order, info_k.order)

  def _add_batch_and_head_dims(t):
    return tf.expand_dims(tf.expand_dims(t, 0), 0)

  q, k, v = map(_add_batch_and_head_dims, (q, k, v))

  def _empty_output():
    zeros = tf.zeros(shape=[1, 1, length_q, depth_v], dtype=tf.float32)
    return tf.Print(zeros, [length_k, length_q], "length_k/length_q: ")

  def _attend():
    # No image summary to avoid "Retval[0] does not have value" (because
    # inside a condition)
    return dot_product_attention(q, k, v, bias=bias, make_image_summary=False)

  # TODO(epot): Should make sure a query gets at least one key. Because the
  # different sequences of a batch are merged, it's possible that a
  # query from a sequence only receive memory from another sequence, so
  # with the mask, the query will perform a softmax on -infinity values.
  # A hack could be to add at least one sequence of each batch on each group so
  # the query can attend to at least one element.
  # Softmax(Q.K)*V
  bucket_is_empty = tf.logical_or(tf.equal(length_q, 0), tf.equal(length_k, 0))
  v_out = tf.cond(bucket_is_empty, _empty_output, _attend)
  # Drop the batch and head dimensions again.
  return tf.squeeze(v_out, axis=[0, 1])
@expert_utils.add_name_scope()
def dot_product_single_head(q, k, v, gates_q, gates_k, bi):
  """Bucketed dot-product attention for a single sequence and a single head.

  Queries are routed with `gates_q` and keys/values with `gates_k`. Attention
  is computed separately inside each bucket, and the per-bucket outputs are
  combined back into query order.

  Args:
    q (tf.Tensor): [length_q, depth_q]
    k (tf.Tensor): [length_k, depth_q]
    v (tf.Tensor): [length_k, depth_v]
    gates_q (tf.Tensor): One-hot vector of shape [length_q, nb_buckets]
    gates_k (tf.Tensor): One-hot vector of shape [length_k, nb_buckets]
    bi (BatchInfo): Contains the batch coordinates and sequence order

  Returns:
    tf.Tensor: [length_q, depth_v]
  """
  nb_buckets = gates_q.get_shape().as_list()[-1]

  q_dispatcher = expert_utils.SparseDispatcher(nb_buckets, gates_q)
  k_dispatcher = expert_utils.SparseDispatcher(nb_buckets, gates_k)

  def dispatch_or_none(dispatcher, value):
    """Dispatches `value`, or gives one None per bucket when it is absent."""
    if value is None:
      return [None] * nb_buckets
    return dispatcher.dispatch(value)

  # Queries, keys and values; values follow the key routing.
  q_parts = q_dispatcher.dispatch(q)
  k_parts = k_dispatcher.dispatch(k)
  v_parts = k_dispatcher.dispatch(v)
  # Batch coordinates and sequence positions, routed the same way.
  q_coords = dispatch_or_none(q_dispatcher, bi.coordinates)
  q_order = dispatch_or_none(q_dispatcher, bi.order)
  k_coords = dispatch_or_none(k_dispatcher, bi.coordinates)
  k_order = dispatch_or_none(k_dispatcher, bi.order)

  bucket_outputs = [
      expert_dot_product(
          q_i,
          k_i,
          v_i,
          info_q=BatchInfo(coordinates=qc, order=qo),
          info_k=BatchInfo(coordinates=kc, order=ko))
      for q_i, k_i, v_i, qc, qo, kc, ko in zip(
          q_parts, k_parts, v_parts, q_coords, q_order, k_coords, k_order)
  ]
  # Restore the original query length.
  return q_dispatcher.combine(bucket_outputs)
def map_fn_switch(fn, elems, use_map_fn=True, **kwargs):
  """Applies `fn` over the first axis with tf.map_fn or a static Python loop.

  Mainly for benchmarking. tf.map_fn is dynamic but much slower than a static
  unrolled graph. The unrolled loop, in turn, makes the graph much longer to
  build and can consume too much RAM in distributed settings.

  Args:
    fn (fct): same as in tf.map_fn, but may only return a single tensor (not a
      tuple of tensors).
    elems (tuple): same as in tf.map_fn
    use_map_fn (bool): If True use tf.map_fn, otherwise unroll with a Python
      for loop.
    **kwargs: Additional tf.map_fn arguments (ignored if use_map_fn is False)

  Returns:
    tf.Tensor: the stacked outputs of `fn`.
  """
  if not use_map_fn:
    # Unstack every element first, then call fn once per aligned slice tuple.
    per_step_slices = zip(*[tf.unstack(e) for e in elems])
    return tf.stack([fn(slices) for slices in per_step_slices])
  return tf.map_fn(fn, elems, **kwargs)
@expert_utils.add_name_scope()
def sparse_dot_product_attention(q, k, v, bi, use_map_fn, experts_params):
  """Sparse multihead self attention.

  Perform an approximation of the full multihead attention by dispatching
  the tokens using their keys/values. Thus the attention matrix are only
  computed each times on a subset of the tokens.

  Notes:
   * The function don't perform scaling here (multihead_attention does
  the /sqrt(depth)).
   * The padding should have been removed (so batch size should be 1 but length
   contains the elements from all different batches)
   * Right now, only self attention is supported so length_q and length_kv
   should be identical and the function will add triangular mask.
   * If bi.order is not None, The bias is added inside this function to
   prevent attention to the future.

  Args:
    q (tf.Tensor): Queries of shape [batch, heads, length_q, depth_k]
    k (tf.Tensor): Keys of shape [batch, heads, length_kv, depth_k]
    v (tf.Tensor): Values of shape [batch, heads, length_kv, depth_v]
    bi (BatchInfo): Contains the batch coordinates and sequence order
    use_map_fn (bool): Use either tf.map_fn of python for loop to compute the
      heads separately
    experts_params (dict): Additional params for the local expert

  Returns:
    tf.Tensor: Approximation of Softmax(Q.K) * V, of shape
      [batch, heads, length_q, depth_v]
    A scalar loss, always 0.0 (total_loss is never accumulated here).
  """
  batch_size, nb_heads, _, depth = common_layers.shape_list(q)

  @expert_utils.add_name_scope()
  def flatten_first_dims(x):
    """Reshape such that x is [num_heads, -1, depth]."""
    # Case 1: Either constant batch size of size 1 or batch already flattened
    if x.get_shape().as_list()[0] == 1:
      return tf.squeeze(x, axis=0)

    # Case 2: Flatten batch dimension
    x = tf.transpose(x, perm=[1, 0, 2, 3])
    x = tf.reshape(x, [nb_heads, -1, depth])
    return x

  def flatten_batch(x):
    if x is None:
      return x
    return expert_utils.flatten_all_but_last(x)

  q = flatten_first_dims(q)
  k = flatten_first_dims(k)
  v = flatten_first_dims(v)
  bi = BatchInfo(
      coordinates=flatten_batch(bi.coordinates),
      order=flatten_batch(bi.order),
  )

  # Unstack heads. Only queries and keys drive the gating, so values do not
  # need to be unstacked.
  list_q = tf.unstack(q)  # list[tf.Tensor(shape=[batch * length, depth])]
  list_k = tf.unstack(k)
  list_gates_q = []
  list_gates_k = []

  total_loss = 0.0
  # There might be a more optimized way to compute all heads at once
  for single_q, single_k in zip(list_q, list_k):
    # Each head get its own dispatcher
    lsh_gating = LshGating(
        depth=single_q.get_shape().as_list()[-1], **experts_params)

    list_gates_q.append(lsh_gating.get_gates(single_q))
    list_gates_k.append(lsh_gating.get_gates(single_k))

  gates_q = tf.stack(list_gates_q)
  gates_k = tf.stack(list_gates_k)

  # Process each head separately.
  v_out = map_fn_switch(
      lambda args: dot_product_single_head(*args, bi=bi),
      elems=(q, k, v, gates_q, gates_k),
      dtype=tf.float32,
      parallel_iterations=2,
      use_map_fn=use_map_fn,
  )

  # Restore original shape as expected by multihead_attention
  if isinstance(batch_size, int) and batch_size == 1:
    v_out = tf.expand_dims(v_out, axis=0)  # Restore batch_size = 1
  else:
    v_out = tf.reshape(v_out, [nb_heads, batch_size, -1, depth])
    v_out = tf.transpose(v_out, [1, 0, 2, 3])
  return v_out, total_loss / nb_heads
@expert_utils.add_name_scope()
def dot_product_batched_head(q, k, v, gates_q, gates_k, mask_right=False):
  """Bucketed dot-product attention over all batch*head rows at once.

  Queries are dispatched into buckets with `gates_q` and keys/values with
  `gates_k`, each truncated to a fixed capacity. Attention is computed
  independently inside each bucket, and the results are combined back into
  query order.

  Args:
    q (tf.Tensor): [batch*heads, length_q, depth_q]
    k (tf.Tensor): [batch*heads, length_k, depth_q]
    v (tf.Tensor): [batch*heads, length_k, depth_v]
    gates_q (tf.Tensor): One-hot of shape [batch*heads, length_q, nb_buckets]
    gates_k (tf.Tensor): One-hot of shape [batch*heads, length_k, nb_buckets]
    mask_right (bool): Add a bias to prevent attention to the future

  Returns:
    tf.Tensor: presumably [batch*heads, length_q, depth_v]; the caller
      reshapes it to [batch, heads, -1, depth].
  """
  nb_buckets = common_layers.shape_list(gates_q)[-1]

  @expert_utils.add_name_scope()
  def get_dispatcher(gates):
    """Construct dispatcher for gates."""
    length = common_layers.shape_list(gates)[1]

    # Count the number of ones per batch (and keep the max value)
    nb_elems_to_dispatch = tf.reduce_sum(gates, axis=[1, 2])
    nb_elems_to_dispatch = tf.reduce_max(nb_elems_to_dispatch)
    nb_elems_to_dispatch = tf.to_int32(nb_elems_to_dispatch)
    # Hardcoded capacity: twice the average bucket load, capped at the length.
    capacity = nb_elems_to_dispatch // nb_buckets * 2
    capacity = tf.minimum(length, capacity)
    tf.summary.scalar("dispatch_capacity", capacity, family="lsh")
    return expert_utils.TruncatingDispatcher(gates, capacity)

  def add_summary_capacity(x, prefix):
    # Monitor if capacity overflow
    x = x[0, ...]  # Take first batch/head
    x = tf.reduce_sum(x, axis=0)
    tf.summary.scalar(prefix + "_min", tf.reduce_min(x), family="lsh")
    tf.summary.scalar(prefix + "_max", tf.reduce_max(x), family="lsh")
    tf.summary.histogram(prefix + "capacity_distribution", x, family="lsh")
    # Show at most the first 3 buckets. Indexing past a statically known
    # bucket count would raise an out-of-bounds error while building the graph.
    nb_shown = min(3, nb_buckets) if isinstance(nb_buckets, int) else 3
    for i in range(nb_shown):
      tf.summary.scalar("{}_{}".format(prefix, i), x[i], family="lsh")

  add_summary_capacity(gates_q, "q")
  add_summary_capacity(gates_k, "k")

  q_dispatcher = get_dispatcher(gates_q)
  k_dispatcher = get_dispatcher(gates_k)

  q = q_dispatcher.dispatch(q)
  k = k_dispatcher.dispatch(k)
  v = k_dispatcher.dispatch(v)

  # Bias of shape [batch*heads, nb_buckets, 1, capacity] broadcasted to every
  # queries
  bias = tf.expand_dims((k_dispatcher.nonpadding() - 1.0) * 1e9, 2)
  if mask_right:
    q_coordinate = tf.to_float(
        tf.expand_dims(q_dispatcher.length_coordinate(), 3))
    k_coordinate = tf.to_float(
        tf.expand_dims(k_dispatcher.length_coordinate(), 2))
    bias += tf.to_float(tf.greater(k_coordinate, q_coordinate)) * -1e9
  # The sequence padding is not masked but is ignored on the next layers

  # q, k, v now have shape [batch*heads, nb_bucket, capacity, depth]
  # The buckets can be seen as different heads
  v_out = dot_product_attention(q, k, v, bias=bias)

  # Combine all buckets together to restore the original length
  return q_dispatcher.combine(v_out)
@expert_utils.add_name_scope()
def sparse_dot_product_attention_truncated(
    q,
    k,
    v,
    bi,  # Unused
    experts_params,
    use_map_fn=False,  # Unused
    mask_right=False,
):  # pylint: disable=unused-argument
  """Sparse multihead self attention with truncated (fixed-capacity) buckets.

  Approximates full multihead attention by hashing tokens into buckets with
  one LSH gating function per head and attending only within each bucket.

  Notes:
   * No scaling is done here (multihead_attention does the /sqrt(depth)).
   * The padding should have been removed (so batch size should be 1 but length
   contains the elements from all different batches)
   * Right now, only self attention is supported so length_q and length_kv
   should be identical.
   * With mask_right, a bias preventing attention to the future is added by
   dot_product_batched_head.

  Args:
    q (tf.Tensor): Queries of shape [batch, heads, length_q, depth_k]
    k (tf.Tensor): Keys of shape [batch, heads, length_q, depth_k]
    v (tf.Tensor): Values of shape [batch, heads, length_kv, depth_v]
    bi (BatchInfo): Unused.
    experts_params (dict): Additional params for the local expert
    use_map_fn (bool): Unused.
    mask_right (bool): If True, prevent attention to the future.

  Returns:
    tf.Tensor: Approximation of Softmax(Q.K) * V, of shape
      [batch, heads, length_q, depth_v]
    A scalar loss, always 0.0.
  """
  # q and v are assumed to share the same depth.
  batch_size, nb_heads, _, depth = common_layers.shape_list(q)
  total_loss = 0.0

  # One LSH gating function per head, shared by that head's queries and keys.
  list_lsh = [LshGating(depth=depth, **experts_params) for _ in range(nb_heads)]

  @expert_utils.add_name_scope()
  def get_gates_head(x, add_first=False):
    """Computes the per-head bucket gates of x.

    Args:
      x (tf.Tensor): of shape [batch, heads, length, depth]
      add_first (bool): if True, add the first element on each bucket

    Returns:
      tf.Tensor: gates of shape [batch, heads, length, num_buckets]
    """
    length = common_layers.shape_list(x)[2]

    # [batch, heads, length, depth] -> [heads, batch * length, depth]
    per_head = tf.reshape(
        tf.transpose(x, perm=[1, 0, 2, 3]),
        [nb_heads, batch_size * length, depth])

    head_gates = []
    for lsh, head_x in zip(list_lsh, tf.unstack(per_head)):
      gates_flat = lsh.get_gates(head_x)
      nb_buckets = gates_flat.get_shape().as_list()[-1]
      # Reshape to [batch, length, buckets]; sequence padding is dispatched
      # like any other position.
      head_gates.append(
          tf.reshape(gates_flat, [batch_size, length, nb_buckets]))

    # [heads, batch, length, buckets] -> [batch, heads, length, buckets]
    stacked = tf.reshape(
        tf.stack(head_gates), [nb_heads, batch_size, length, nb_buckets])
    gates = tf.transpose(stacked, [1, 0, 2, 3])

    if add_first:
      # Dispatch the first element to every bucket so none of them is empty.
      first_position = tf.reshape(tf.one_hot([0], length), [1, 1, length, 1])
      gates = tf.maximum(gates, first_position)

    return gates

  gates_q = get_gates_head(q)
  gates_k = get_gates_head(k, add_first=True)

  # Merge batch and heads: [batch, heads, length, d] -> [batch*heads, length, d]
  merged = [combine_first_two_dimensions(t) for t in (q, k, v, gates_q, gates_k)]
  q, k, v, gates_q, gates_k = merged

  v_out = dot_product_batched_head(q, k, v, gates_q, gates_k, mask_right)

  # Split batch and heads apart again.
  return tf.reshape(v_out, [batch_size, nb_heads, -1, depth]), total_loss / nb_heads
@expert_utils.add_var_scope()
def deconv_elems_1d(x, factor, out_depth=None):
  """Increase the length and change the dimensionality.

  Every input position (of dim depth) is projected into `factor` consecutive
  output positions of dim out_depth, using a transposed convolution whose
  kernel width equals its stride (so the expansions do not overlap).

  Args:
    x (tf.Tensor): shape [batch_size, length, depth]
    factor (int): Multiplicative factor of each token.
    out_depth (int): Output depth (if None, keep depth constant)

  Returns:
    tf.Tensor: shape [batch_size, length*factor, out_depth]
  """
  if not out_depth:
    out_depth = x.get_shape().as_list()[-1]
  # View the sequence as a 1-pixel-high image so the 2D transposed conv with
  # a (1, factor) kernel/stride only acts along the length axis.
  image = tf.expand_dims(x, 1)  # [batch_size, 1, length, depth]
  expand = layers().Conv2DTranspose(
      filters=out_depth,
      kernel_size=(1, factor),
      strides=(1, factor),
      padding="valid",
      data_format="channels_last",
  )
  expanded = expand(image)  # [batch_size, 1, length*factor, out_depth]
  return tf.squeeze(expanded, 1)  # [batch_size, length*factor, out_depth]
@expert_utils.add_var_scope()
def conv_elems_1d(x, factor, out_depth=None):
  """Decrease the length and change the dimensionality.

  Merges each group of `factor` consecutive positions of the input into a
  single position of dim out_depth. This is a strided convolution whose
  kernel width equals its stride, so the groups do not overlap. Because the
  convolution uses "valid" padding, trailing positions that do not fill a
  whole group are dropped rather than padded.

  Args:
    x (tf.Tensor): shape [batch_size, length, depth]
    factor (int): Length compression factor.
    out_depth (int): Output depth (if None, keep depth constant)

  Returns:
    tf.Tensor: shape [batch_size, length//factor, out_depth]
  """
  if not out_depth:
    out_depth = x.get_shape().as_list()[-1]
  # View the sequence as a 1-pixel-high image so the 2D conv with a
  # (1, factor) kernel/stride only acts along the length axis.
  image = tf.expand_dims(x, 1)  # [batch_size, 1, length, depth]
  compress = layers().Conv2D(
      filters=out_depth,
      kernel_size=(1, factor),
      strides=(1, factor),
      padding="valid",
      data_format="channels_last",
  )
  compressed = compress(image)  # [batch_size, 1, length//factor, out_depth]
  return tf.squeeze(compressed, 1)  # [batch_size, length//factor, out_depth]
@expert_utils.add_var_scope()
def local_reduction_attention(x, block_length, multihead_params):
  """Reduce the length dimension using self attention.

  The sequence is cut into non-overlapping blocks of `block_length`
  positions, and each block is collapsed into a single position by a local
  attention (see `dot_product_self_local_attention_flattened`).

  Args:
    x (tf.Tensor): float32 of shape [batch, length, depth]
    block_length (int): Block length for local attention (compression factor)
    multihead_params (dict): parameters for multihead attention. It must not
      contain `attention_type` or `output_depth`, which are set here.

  Returns:
    tf.Tensor: Compressed tensor of shape
      [batch, ceil(length / block_length), depth]. The attention function
      emits one position per block, and the last block is zero-padded when
      length is not a multiple of block_length.
  """

  @expert_utils.add_name_scope()
  def dot_product_self_local_attention_flattened(q, k, v):
    """Strided block local self-attention.

    No overlap between the blocks. Within a block, a single softmax is taken
    jointly over all (query, key) pairs and then summed over the queries, so
    each block yields one weighted combination of its values.

    Args:
      q (tf.Tensor): shape [batch, heads, length, depth_k]
      k (tf.Tensor): shape [batch, heads, length, depth_k]
      v (tf.Tensor): shape [batch, heads, length, depth_v]

    Returns:
      tf.Tensor: shape [batch, heads, ceil(length / block_length), depth_v]
    """
    num_head = q.get_shape().as_list()[1]

    # Extract the blocks.
    def pad_and_reshape(t):
      """Split the length dim into [num_block, block_length]."""
      # Read the depth before padding so its static value is kept. Each tensor
      # uses its own depth: depth_k for q/k, depth_v for v.
      depth_t = common_layers.shape_list(t)[3]
      length_t = common_layers.shape_list(t)[2]
      # Add some padding, but won't matter as the last block will never be
      # attended by the query (after compression).
      t = tf.pad(t, [[0, 0], [0, 0], [0, -length_t % block_length], [0, 0]])
      t = tf.reshape(
          t,
          [
              common_layers.shape_list(t)[0],  # Batch
              num_head,  # Head
              common_layers.shape_list(t)[2] // block_length,  # Num blocks
              block_length,  # Block length
              depth_t,  # Depth
          ])
      return t

    q, k, v = [pad_and_reshape(t) for t in (q, k, v)]

    # Perform attention on the flattened dot product:
    # [batch, heads, num_blocks, block_length, block_length].
    logits = tf.matmul(q, k, transpose_b=True)
    logits = tf.reshape(
        logits,
        [
            common_layers.shape_list(logits)[0],  # Batch
            num_head,  # Head
            common_layers.shape_list(logits)[2],  # Num blocks
            block_length**2,  # Flatten last dimension
        ])
    weights = tf.nn.softmax(logits)
    weights = tf.reshape(
        weights,
        [
            common_layers.shape_list(weights)[0],  # Batch
            num_head,  # Head
            common_layers.shape_list(weights)[2],  # Num blocks
            block_length,
            block_length,  # Restore the block length dimension
        ])
    # Sum over the queries of the block: [..., 1, block_length].
    weights = tf.reduce_sum(weights, axis=3, keepdims=True)
    v_out = tf.matmul(weights, v)  # [1, block_length] @ [block_length, depth_v]
    v_out = tf.squeeze(v_out, axis=3)
    return v_out

  return multihead_attention(
      x,
      None,
      bias=None,
      output_depth=x.get_shape().as_list()[-1],
      attention_type=dot_product_self_local_attention_flattened,
      **multihead_params)
@expert_utils.add_var_scope()
def multihead_self_attention_reduced(
    x,
    memory_antecedent=None,
    bias=None,
    factor=None,
    multihead_params=None,
    nonlinearity="none",
    reduction_type="conv",
    add_mask=True,
):
  """Self-attention against a memory whose length is reduced by `factor`.

  The memory is a compressed copy of `x` (strided conv or local attention),
  prefixed with the first element of `x` so that every query has at least
  one memory position it is allowed to attend to.

  Args:
    x (tf.Tensor): float32 of shape [batch, length, depth]
    memory_antecedent (tf.Tensor): Unsupported for now
    bias (tf.Tensor): Ignored (rebuilt from `add_mask`)
    factor (int): compression factor for the memory sequence
    multihead_params (dict): parameters for multihead attention
    nonlinearity (str): Add some non-linearity after the memory block
      ("none" or "silu")
    reduction_type (str): type of compression ("conv" or "attention")
    add_mask (bool): If True, add the bias to prevent attention to the future

  Returns:
    (tf.Tensor): float32 of shape [batch, length, depth]

  Raises:
    ValueError: If factor/multihead_params are unset, or if reduction_type or
      nonlinearity is invalid.
    NotImplementedError: If memory_antecedent is given.
  """
  if not factor or not multihead_params:
    raise ValueError("factor and multihead_params should be set")
  if memory_antecedent is not None:
    raise NotImplementedError(
        "multihead_self_attention_reduced only works with self-attention")
  # Validate every option before building any op, so a bad argument does not
  # leave orphan compression variables behind in the graph.
  if reduction_type not in ("attention", "conv"):
    raise ValueError("Unknown reduction type {}".format(reduction_type))
  if nonlinearity not in ("silu", "none"):
    raise ValueError("Unknown non linearity {}".format(nonlinearity))

  depth = x.get_shape().as_list()[-1]

  # Could try to have some overlap between the blocks but that would
  # create conv artifacts, would make it difficult to not attend to the future
  # within one group and the padding should be handled specially.

  # Reduce the memory dimension.
  if reduction_type == "attention":
    memory_x = local_reduction_attention(x, factor, multihead_params)
  else:  # reduction_type == "conv"
    # With valid padding, the last block won't be computed (not attended anyway)
    memory_x = conv_elems_1d(x, factor)

  if nonlinearity == "silu":
    memory_x *= tf.nn.sigmoid(memory_x)

  memory_x = tf.concat(
      # Add the first elem to make it attendable by everyone (otherwise the
      # first block cannot attend to anything).
      [x[:, :1, :], memory_x],
      axis=1,
  )

  # Construct the bias.
  @expert_utils.add_name_scope()
  def construct_bias_vectors(t, axis):
    """Float positions [0, ..., length-1] of t's axis 1, expanded at `axis`."""
    length = tf.to_float(common_layers.shape_list(t)[1])
    length_coordinates = tf.range(length, dtype=tf.float32)
    length_coordinates = tf.expand_dims(length_coordinates, axis=axis)
    # [1, length_k] or [length_q, 1]
    return length_coordinates

  if add_mask:  # Create mask to prevent attention to the future
    bias = tf.to_float(
        tf.greater(
            # Because we add the first elem to the memory block and it can be
            # attended by anyone, we don't need to add +1 anymore to prevent
            # self attention. Use * factor to make sure the last tokens of a
            # block cannot attend the block.
            construct_bias_vectors(memory_x, 0) * factor,
            # +epsilon to avoid float equality
            construct_bias_vectors(x, 1) + 1e-3,
        )) * -1e9
    # tf.greater broadcasts [1, length_k] against [length_q, 1], so the mask
    # is [length_q, length_k] before the two leading axes are added.
    bias = tf.expand_dims(bias, axis=0)
    bias = tf.expand_dims(bias, axis=0)  # [1, 1, length_q, length_k]
  else:
    bias = None

  return multihead_attention(
      query_antecedent=x,
      memory_antecedent=memory_x,
      bias=bias,
      output_depth=depth,
      **multihead_params)
def scaled_dot_product_attention_simple(q, k, v, bias, name=None):
  """Single-head scaled dot-product attention over one spatial dimension.

  Computes softmax(q . k^T / sqrt(depth_k) + bias) . v.

  Args:
    q: a Tensor with shape [batch, length_q, depth_k]
    k: a Tensor with shape [batch, length_kv, depth_k]
    v: a Tensor with shape [batch, length_kv, depth_v]
    bias: optional Tensor broadcastable to [batch, length_q, length_kv]
    name: an optional string

  Returns:
    A Tensor of shape [batch, length_q, depth_v].
  """
  with tf.variable_scope(
      name, default_name="scaled_dot_product_attention_simple"):
    depth_k = common_layers.shape_list(q)[2]
    inv_sqrt_depth = tf.rsqrt(tf.to_float(depth_k))
    scores = tf.matmul(q * inv_sqrt_depth, k, transpose_b=True)
    if bias is not None:
      scores += bias
    probs = tf.nn.softmax(scores, name="attention_weights")
    if common_layers.should_generate_summaries():
      # The 0.2 power presumably brightens small weights in the image summary.
      tf.summary.image(
          "attention", tf.expand_dims(tf.pow(probs, 0.2), 3), max_outputs=1)
    return tf.matmul(probs, v)
# Module-level cache of the Defun-wrapped forward functions built by
# multihead_self_attention_memory_efficient. It is keyed by a string built from
# (num_heads, epsilon), so the custom-gradient function for a given
# configuration is only constructed once.
_function_cache = {}
def multihead_self_attention_memory_efficient(x,
                                              bias,
                                              num_heads,
                                              head_size=None,
                                              epsilon=1e-6,
                                              forget=True,
                                              test_vars=None,
                                              name=None):
  """Multihead scaled-dot-product self-attention.

  Includes layer norm.

  Returns multihead-self-attention(layer_norm(x))

  Computes one attention head at a time to avoid exhausting memory.

  If forget=True, then forget all forwards activations and recompute on
  the backwards pass.

  Args:
    x: a Tensor with shape [batch, length, input_size]
    bias: an attention bias tensor broadcastable to [batch, 1, length, length].
      Its axis 1 is squeezed away, so it must have size 1 there.
      NOTE(review): with forget=True the bias is passed as an argument to a
      Defun (and zeros_like'd in the gradient), so bias=None is presumably
      only supported with forget=False.
    num_heads: an integer
    head_size: an optional integer - defaults to input_size // num_heads
    epsilon: a float, for layer norm
    forget: a boolean - forget forwards activations and recompute on backprop
    test_vars: optional tuple of variables for testing purposes
    name: an optional string

  Returns:
    A Tensor.

  Raises:
    ValueError: if head_size is None and input_size is not divisible by
      num_heads.
  """
  io_size = x.get_shape().as_list()[-1]
  if head_size is None:
    if io_size % num_heads != 0:
      raise ValueError(
          "input_size (%s) must be divisible by num_heads (%s)" %
          (io_size, num_heads))
    # Integer division: head_size is used as a variable shape dimension, and
    # true division would yield a float (e.g. 4.0) there.
    head_size = io_size // num_heads

  def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
    """Forward function."""
    n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias)
    wqkv_split = tf.unstack(wqkv, num=num_heads)
    wo_split = tf.unstack(wo, num=num_heads)
    y = 0
    for h in range(num_heads):
      # Chain each head on the previous partial sum so heads are computed
      # sequentially rather than all at once.
      with tf.control_dependencies([y] if h > 0 else []):
        combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
        q, k, v = tf.split(combined, 3, axis=2)
        o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
        y += tf.nn.conv1d(o, wo_split[h], 1, "SAME")
    return y

  key = (
      "multihead_self_attention_memory_efficient %s %s" % (num_heads, epsilon))
  if not forget:
    forward_fn = forward_internal
  elif key in _function_cache:
    forward_fn = _function_cache[key]
  else:

    @function.Defun(compiled=True)
    def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy):
      """Custom gradient function."""
      with tf.control_dependencies([dy]):
        # Recompute the forward activations (they were not kept) head by head.
        n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias)
        wqkv_split = tf.unstack(wqkv, num=num_heads)
        wo_split = tf.unstack(wo, num=num_heads)
        deps = []
        dwqkvs = []
        dwos = []
        dn = 0
        for h in range(num_heads):
          with tf.control_dependencies(deps):
            combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
            q, k, v = tf.split(combined, 3, axis=2)
            o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
            partial_y = tf.nn.conv1d(o, wo_split[h], 1, "SAME")
            pdn, dwqkvh, dwoh = tf.gradients(
                ys=[partial_y],
                xs=[n, wqkv_split[h], wo_split[h]],
                grad_ys=[dy])
            dn += pdn
            dwqkvs.append(dwqkvh)
            dwos.append(dwoh)
            deps = [dn, dwqkvh, dwoh]
        dwqkv = tf.stack(dwqkvs)
        dwo = tf.stack(dwos)
        with tf.control_dependencies(deps):
          dx, dnorm_scale, dnorm_bias = tf.gradients(
              ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn])
        # The bias receives no gradient.
        return (dx, dwqkv, dwo, tf.zeros_like(attention_bias), dnorm_scale,
                dnorm_bias)

    @function.Defun(
        grad_func=grad_fn, compiled=True, separate_compiled_gradients=True)
    def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
      return forward_internal(x, wqkv, wo, attention_bias, norm_scale,
                              norm_bias)

    _function_cache[key] = forward_fn

  if bias is not None:
    bias = tf.squeeze(bias, 1)
  with tf.variable_scope(name, default_name="multihead_attention", values=[x]):
    # TODO(noam): it would be nice to save memory by casting x to float16
    # here, but this causes problems with the gradients.  Figure out if there
    # is a way to leave the gradients as float32.
    if test_vars is not None:
      wqkv, wo, norm_scale, norm_bias = list(test_vars)
    else:
      wqkv = tf.get_variable(
          "wqkv", [num_heads, 1, io_size, 3 * head_size],
          initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
      wo = tf.get_variable(
          "wo", [num_heads, 1, head_size, io_size],
          initializer=tf.random_normal_initializer(
              stddev=(head_size * num_heads)**-0.5))
      norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias)
    y.set_shape(x.get_shape())
    return y
# Variants of multihead_attention with the `attention_type` argument
# pre-bound to the sparse attention functions defined in this module.
multihead_attention_sparse_dot_prod = functools.partial(
    multihead_attention, attention_type=sparse_dot_product_attention)
multihead_attention_sparse_truncated = functools.partial(
    multihead_attention, attention_type=sparse_dot_product_attention_truncated)
================================================
FILE: tensor2tensor/layers/common_attention_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for common attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from absl.testing import parameterized
import kfac
import numpy as np
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import contrib
from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf
# Eager-execution module obtained through tensor2tensor's contrib shim.
tfe = contrib.tfe()
# Legacy direct import, superseded by the contrib shim above:
# from tensorflow.contrib.eager.python import tfe as tfe
# Eager execution is enabled at import time; undecorated tests below therefore
# run eagerly, while tests wrapped in test_utils decorators presumably switch
# modes as their names suggest (graph-and-eager / graph-only).
tf.enable_eager_execution()
class CommonAttentionTest(parameterized.TestCase, tf.test.TestCase):
@test_utils.run_in_graph_and_eager_modes()
def testAttentionBiasLocal(self):
  """attention_bias_local(length, 0, 0) masks everything but the diagonal."""
  length = 5
  bias = common_attention.attention_bias_local(length, 0, 0)
  # For length = 5
  # [[[[-0.e+00 -1.e+09 -1.e+09 -1.e+09 -1.e+09]
  #    [-1.e+09 -0.e+00 -1.e+09 -1.e+09 -1.e+09]
  #    [-1.e+09 -1.e+09 -0.e+00 -1.e+09 -1.e+09]
  #    [-1.e+09 -1.e+09 -1.e+09 -0.e+00 -1.e+09]
  #    [-1.e+09 -1.e+09 -1.e+09 -1.e+09 -0.e+00]]]]
  res = self.evaluate(bias)
  # -1e9 everywhere, 0 on the diagonal; reshaped to [1, 1, length, length].
  expected_res = -1e9 * np.ones((length, length)) - -1e9 * np.identity(length)
  expected_res = np.reshape(expected_res, (1, 1, length, length))
  self.assertAllClose(res, expected_res)
@test_utils.run_in_graph_and_eager_modes()
def testAddPositionalEmbedding(self):
  """add_positional_embedding keeps the [batch, length, depth] shape."""
  inputs = tf.constant(np.random.rand(5, 3, 12), dtype=tf.float32)
  outputs = common_attention.add_positional_embedding(
      inputs, max_length=4, name="pos_embedding")
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 3, 12))
@parameterized.named_parameters(
    ("hard_top_k", 0.0),
    ("sampled_top_k_default", 1.0),
    ("sampled_top_k_2", 2.0),
)
@test_utils.run_in_graph_and_eager_modes()
def testHardenAttentionWeights(self, gumbel_noise_weight):
  """harden_attention_weights keeps the shape of softmax-normalized weights."""
  logits = tf.constant(np.random.rand(5, 3, 12), dtype=tf.float32)
  hardened = common_attention.harden_attention_weights(
      tf.nn.softmax(logits), 3, gumbel_noise_weight)
  self.assertEqual(self.evaluate(hardened).shape, (5, 3, 12))
@parameterized.named_parameters(
    ("hard_top_k", -0.5),
    ("sampled_top_k", 0.5),
)
@test_utils.run_in_graph_and_eager_modes()
def testHardenAttentionAllZeros(self, gumbel_noise_weight):
  """Check if the hardening code does not divide by zero for all zeros."""
  x = np.zeros((5, 3, 12), dtype=np.float32)
  y = common_attention.harden_attention_weights(
      tf.constant(x, dtype=tf.float32), 3, gumbel_noise_weight)
  res = self.evaluate(y)
  # A non-positive noise weight is the deterministic (hard top-k) case, which
  # must leave all-zero weights unchanged.
  # NOTE(review): the sampled case (weight > 0) only checks that evaluation
  # succeeds; it does not assert the result is finite.
  if gumbel_noise_weight <= 0.0:
    self.assertAllClose(res, x)
@parameterized.parameters(
    {"input_shape": (5, 3, 12)},
    {"input_shape": (5, 5, 5, 12)},
    {"input_shape": (5, 3, 3, 3, 12)},
)
@test_utils.run_in_graph_and_eager_modes()
def testAddPositionalEmbeddingNd(self, input_shape):
  """add_positional_embedding_nd keeps the shape for 1 to 3 spatial dims."""
  inputs = tf.constant(np.random.rand(*input_shape), dtype=tf.float32)
  outputs = common_attention.add_positional_embedding_nd(
      inputs, max_length=5, name="pos_embedding")
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, input_shape)
@test_utils.run_in_graph_and_eager_modes()
def testAddTimingSignalsGivenPositions(self):
  """Each positions tensor contributes its own block of timing channels.

  With depth 8 and two positions tensors (x: 0..3, y: 4..7), every position
  row is expected to hold [sin(p), sin(p*1e-4), cos(p), cos(p*1e-4)] for x
  followed by the same four features for y.
  """
  x_positions = tf.expand_dims(
      tf.constant([0, 1, 2, 3], dtype=tf.float32), axis=0)
  y_positions = tf.expand_dims(
      tf.constant([4, 5, 6, 7], dtype=tf.float32), axis=0)
  x = tf.zeros([1, 4, 8], dtype=tf.float32)
  self.assertAllClose(
      common_attention.add_timing_signals_given_positions(
          x, [x_positions, y_positions]),
      tf.constant([[
          [
              math.sin(0),
              math.sin(0 * 1e-4),
              math.cos(0),
              math.cos(0 * 1e-4),
              math.sin(4),
              math.sin(4 * 1e-4),
              math.cos(4),
              math.cos(4 * 1e-4)
          ],
          [
              math.sin(1),
              math.sin(1 * 1e-4),
              math.cos(1),
              math.cos(1 * 1e-4),
              math.sin(5),
              math.sin(5 * 1e-4),
              math.cos(5),
              math.cos(5 * 1e-4)
          ],
          [
              math.sin(2),
              math.sin(2 * 1e-4),
              math.cos(2),
              math.cos(2 * 1e-4),
              math.sin(6),
              math.sin(6 * 1e-4),
              math.cos(6),
              math.cos(6 * 1e-4)
          ],
          [
              math.sin(3),
              math.sin(3 * 1e-4),
              math.cos(3),
              math.cos(3 * 1e-4),
              math.sin(7),
              math.sin(7 * 1e-4),
              math.cos(7),
              math.cos(7 * 1e-4)
          ],
      ]]))
@test_utils.run_in_graph_and_eager_modes()
def testAddTimingSignalsGivenPositionsEquivalent(self):
  """The 1d timing signal equals the general one with a single positions.

  The method add_timing_signal_1d_given_position could be replaced by
  add_timing_signals_given_positions.
  """
  x = tf.zeros([1, 10, 128], dtype=tf.float32)
  positions = tf.expand_dims(tf.range(0, 10, dtype=tf.float32), axis=0)
  # Evaluate both signals and compare explicitly. A bare tf.assert_equal only
  # builds an assertion op, which nobody runs in graph mode, so the graph
  # half of this test previously checked nothing.
  expected, actual = self.evaluate([
      common_attention.add_timing_signal_1d_given_position(x, positions),
      common_attention.add_timing_signals_given_positions(x, [positions]),
  ])
  self.assertAllClose(expected, actual)
@test_utils.run_in_graph_and_eager_modes()
def testDotProductAttention(self):
  """dot_product_attention returns [batch, heads, length, depth_v]."""
  shape = (5, 7, 12, 32)
  queries = tf.constant(np.random.rand(*shape), dtype=tf.float32)
  memory = np.random.rand(*shape)
  keys = tf.constant(memory, dtype=tf.float32)
  values = tf.constant(memory, dtype=tf.float32)
  attended = common_attention.dot_product_attention(queries, keys, values,
                                                     None)
  self.assertEqual(self.evaluate(attended).shape, shape)
@parameterized.parameters(
    ([3, 10, 64], 4),
    ([3, 10, 20, 64], 2),
    ([3, 10, 20, 30, 64], 4),
)
def testSplitHeadsND(self, shape, num_heads):
  """split_heads_nd inserts a heads axis and divides the channel depth."""
  split = common_attention.split_heads_nd(tf.zeros(shape), num_heads)
  expected_shape = ((shape[0], num_heads) + tuple(shape[1:-1]) +
                    (shape[-1] // num_heads,))
  self.assertEqual(self.evaluate(split).shape, expected_shape)
@parameterized.parameters(
    ([3, 4, 10, 64],),
    ([3, 2, 10, 20, 64],),
    ([3, 4, 10, 20, 30, 64],),
)
def testCombineHeadsND(self, shape):
  """combine_heads_nd folds the heads axis (axis 1) into the channels."""
  combined = common_attention.combine_heads_nd(tf.zeros(shape))
  expected_shape = ((shape[0],) + tuple(shape[2:-1]) +
                    (shape[-1] * shape[1],))
  self.assertEqual(self.evaluate(combined).shape, expected_shape)
@parameterized.parameters(
    ([3, 4, 10, 64], (5,), (10,)),
    ([3, 4, 10, 10, 64], (5, 5), (5, 5)),
    ([3, 4, 10, 10, 10, 64], (5, 5, 5), (5, 5, 5)),
)
def testShapeMaskedLocalAttentionND(self, shape, query_shape, memory_flange):
  """masked_local_attention_nd preserves the input shape."""
  qkv = tf.reshape(tf.range(np.prod(shape), dtype=tf.float32), shape)
  attended = common_attention.masked_local_attention_nd(
      qkv, qkv, qkv, query_shape, memory_flange)
  self.assertEqual(self.evaluate(attended).shape, tuple(shape))
@test_utils.run_in_graph_and_eager_modes()
def testRightShiftBlockwiseND(self):
  """right_shift_blockwise_nd shifts by one in blockwise raster order.

  Reading the 4x4 input block by block (2x2 blocks, raster order inside each
  block) gives 1, 2, 5, 6, 3, 4, 7, 8, ...; the expected output is that
  sequence shifted right by one with a leading 0, written back in place.
  """
  tensor = tf.convert_to_tensor(np.array([[
      [[1], [2], [3], [4]],
      [[5], [6], [7], [8]],
      [[9], [10], [11], [12]],
      [[13], [14], [15], [16]],
  ]], dtype=np.float32))
  val = common_attention.right_shift_blockwise_nd(tensor, (2, 2))
  res = self.evaluate(val)
  expected_val = np.array([[
      [[0], [1], [6], [3]],
      [[2], [5], [4], [7]],
      [[8], [9], [14], [11]],
      [[10], [13], [12], [15]],
  ]], dtype=np.float32)
  self.assertAllClose(expected_val, res)
@test_utils.run_in_graph_and_eager_modes()
def testContentMaskedLocalAttentionND(self):
  """Checks masked_local_attention_nd values against a hand computation.

  With query_shape (1, 1) and memory_flange (1, 1), each pixel of the 4x4
  depth-1 image attends over 5 memory slots. Unmasked logits are the
  products q*k (e.g. 0.1*0.1 = 0.01) and masked slots use -1e9. The expected
  output is the softmax over those 5 logits dotted with the matching
  blocked values.
  """

  def softmax(arr):
    # Plain (non-stabilized) softmax over a 1-D list of logits.
    return np.exp(arr) / np.sum(np.exp(arr))

  q = k = v = tf.convert_to_tensor(
      np.array([[[
          [[0.1], [0.1], [0.1], [0.1]],
          [[0.1], [1.0], [1.0], [0.1]],
          [[0.1], [1.0], [1.0], [0.1]],
          [[0.1], [0.1], [0.1], [0.1]],
      ]]], dtype=np.float32))
  # Per-pixel attention weights over the 5 memory slots.
  attn_weights = np.array([[[[softmax([-1e9, -1e9, -1e9, -1e9, 0.01]),
                              softmax([-1e9, -1e9, -1e9, 0.01, 0.01]),
                              softmax([-1e9, -1e9, -1e9, 0.01, 0.01]),
                              softmax([-1e9, -1e9, -1e9, 0.01, 0.01])
                             ],
                             [softmax([-1e9, 0.01, 0.01, -1e9, 0.01]),
                              softmax([0.1, 0.1, 0.1, 0.1, 1.0]),
                              softmax([0.1, 0.1, 0.1, 1.0, 1.0]),
                              softmax([0.01, 0.01, -1e9, 0.1, 0.01])
                             ],
                             [softmax([-1e9, 0.01, 0.1, -1e9, 0.01]),
                              softmax([0.1, 1.0, 1.0, 0.1, 1.0]),
                              softmax([1.0, 1.0, 0.1, 1.0, 1.0]),
                              softmax([0.1, 0.01, -1e9, 0.1, 0.01])
                             ],
                             [softmax([-1e9, 0.01, 0.1, -1e9, 0.01]),
                              softmax([0.01, 0.1, 0.1, 0.01, 0.01]),
                              softmax([0.1, 0.1, 0.01, 0.01, 0.01]),
                              softmax([0.1, 0.01, -1e9, 0.01, 0.01])
                             ]]]])
  # Values of the same 5 memory slots for each pixel (0 where masked/padded).
  blocked_v = np.array([[[[[0, 0, 0, 0, 0.1],
                           [0, 0, 0, 0.1, 0.1],
                           [0, 0, 0, 0.1, 0.1],
                           [0, 0, 0, 0.1, 0.1]],
                          [[0, 0.1, 0.1, 0, 0.1],
                           [0.1, 0.1, 0.1, 0.1, 1],
                           [0.1, 0.1, 0.1, 1, 1],
                           [0.1, 0.1, 0, 1, 0.1]],
                          [[0, 0.1, 1, 0, 0.1],
                           [0.1, 1, 1, 0.1, 1],
                           [1, 1, 0.1, 1, 1],
                           [1, 0.1, 0, 1, 0.1]],
                          [[0, 0.1, 1, 0, 0.1],
                           [0.1, 1, 1, 0.1, 0.1],
                           [1, 1, 0.1, 0.1, 0.1],
                           [1, 0.1, 0, 0.1, 0.1]]]]])
  expected_val = np.expand_dims(
      np.sum(attn_weights * blocked_v, axis=4), axis=-1)
  val = common_attention.masked_local_attention_nd(q, k, v, (1, 1), (1, 1))
  res = self.evaluate(val)
  self.assertAllClose(expected_val, res)
@test_utils.run_in_graph_and_eager_modes()
def testSelectBlockForDecodeStep(self):
  """Decode step 20 with (2, 2) query blocks selects block (0, 5)."""
  # [batch=2, 6, 6 blocks, 4 elements per block, 1 channel]
  tensor = tf.reshape(
      tf.range(2 * 6 * 6 * 4, dtype=tf.float32), [2, 6, 6, 4, 1])
  block = common_attention.select_block_for_decode_step(tensor, 20, (2, 2))
  expected_tensor = tensor[:, 0:1, 5:6, :, :]
  expected_value = self.evaluate(expected_tensor)
  res = self.evaluate(block)
  self.assertAllClose(expected_value, res)
@parameterized.parameters(
    ((2, 6, 4, 10),),
    ((2, 6, 6, 4, 10),),
    ((2, 6, 6, 6, 4, 10),),
)
def testFlattenBlocksND(self, shape):
  """flatten_blocks_nd merges all block-grid axes into a single axis."""
  flat, _ = common_attention.flatten_blocks_nd(
      tf.zeros(shape, dtype=tf.float32))
  num_blocks = np.prod(shape[1:-2])
  self.assertAllClose(self.evaluate(flat).shape,
                      (shape[0], num_blocks, shape[-2], shape[-1]))
@parameterized.parameters(
    ((5,),),
    ((5, 10),),
    ((5, 10, 15),),
)
def testUnflattenBlocksND(self, blocks_per_dim):
  """unflatten_blocks_nd restores the block-grid axes from a flat axis."""
  flat = tf.zeros([2, np.prod(blocks_per_dim), 6, 10])
  blocks = common_attention.unflatten_blocks_nd(flat, blocks_per_dim)
  self.assertAllClose(self.evaluate(blocks).shape,
                      (2,) + blocks_per_dim + (6, 10))
@test_utils.run_in_graph_and_eager_modes()
def testBreakIntoMemoryBlocksND(self):
  """Memory blocks for 2x2 queries with a (2, 2) flange, masked.

  Each of the 2x2 query blocks of the 4x4 image gets 24 memory positions,
  i.e. 6 neighbouring blocks of 4, each block flattened in raster order
  (e.g. [1, 2, 5, 6]). Positions outside the image are zeros.
  """
  tensor = tf.convert_to_tensor(
      np.array([[
          [[1], [2], [3], [4]],
          [[5], [6], [7], [8]],
          [[9], [10], [11], [12]],
          [[13], [14], [15], [16]],
      ]]))
  value = common_attention.break_into_memory_blocks_nd(tensor,
                                                       (2, 2),
                                                       (2, 2),
                                                       masked=True)
  res = self.evaluate(value)
  expected_value = np.array([[
      [
          [
              [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0],
              [0], [0], [0], [0], [1], [2], [5], [6], [3], [4], [7], [8]
          ],
          [
              [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0],
              [1], [2], [5], [6], [3], [4], [7], [8], [0], [0], [0], [0]
          ]
      ],
      [
          [
              [0], [0], [0], [0], [1], [2], [5], [6], [3], [4], [7], [8], [0],
              [0], [0], [0], [9], [10], [13], [14], [11], [12], [15], [16]
          ],
          [
              [1], [2], [5], [6], [3], [4], [7], [8], [0], [0], [0], [0], [9],
              [10], [13], [14], [11], [12], [15], [16], [0], [0], [0], [0]
          ]
      ]]])
  self.assertAllClose(expected_value, res)
@test_utils.run_in_graph_and_eager_modes()
def testBreakIntoBlocksND(self):
  """Splits a 4x4 image into a 2x2 grid of 2x2 blocks, each flattened."""
  tensor = tf.convert_to_tensor(
      np.array([[
          [[1], [2], [3], [4]],
          [[5], [6], [7], [8]],
          [[9], [10], [11], [12]],
          [[13], [14], [15], [16]],
      ]]))
  value = common_attention.break_into_blocks_nd(tensor, (2, 2))
  res = self.evaluate(value)
  # [batch, block_rows, block_cols, 4 elements (raster order), channels]
  expected_value = np.array([[
      [[[1], [2], [5], [6]], [[3], [4], [7], [8]]],
      [[[9], [10], [13], [14]], [[11], [12], [15], [16]]]
  ]])
  self.assertAllClose(expected_value, res)
@test_utils.run_in_graph_and_eager_modes()
def testPutBackBlocksND(self):
  """put_back_blocks_nd reassembles 2x2 flattened blocks into a 4x4 image."""
  # Same layout as the expected output of break_into_blocks_nd for 2x2 blocks.
  tensor = tf.convert_to_tensor(
      np.array([[
          [[[1], [2], [5], [6]], [[3], [4], [7], [8]]],
          [[[9], [10], [13], [14]], [[11], [12], [15], [16]]]
      ]]))
  value = common_attention.put_back_blocks_nd(tensor, (2, 2))
  res = self.evaluate(value)
  expected_value = np.array([[
      [[1], [2], [3], [4]],
      [[5], [6], [7], [8]],
      [[9], [10], [11], [12]],
      [[13], [14], [15], [16]],
  ]])
  self.assertAllClose(expected_value, res)
@parameterized.parameters(
    ((2, 100, 5), (7,), (2, 105, 5)),
    ((2, 100, 100, 5), (5, 7), (2, 100, 105, 5)),
    ((2, 100, 100, 100, 5), (10, 20, 30), (2, 100, 100, 120, 5))
)
def testPadToMultipleND(self, tensor_shape, block_shape, expected_shape):
  """pad_to_multiple_nd rounds each spatial dim up to the block size."""
  padded = common_attention.pad_to_multiple_nd(
      tf.zeros(tensor_shape), block_shape)
  self.assertAllClose(self.evaluate(padded).shape, expected_shape)
@test_utils.run_in_graph_and_eager_modes()
def testCausalAttentionBiasND(self):
  """Causal bias for (2, 2) queries and a (2, 2) flange.

  The bias is [1, 1, 4, 24]: query i (0..3) may attend to the first 17 + i
  of the 24 memory positions, and the rest get -1e9.
  """
  bias = common_attention.causal_attention_bias_nd((2, 2), (2, 2))
  res = self.evaluate(bias)
  expected_val = np.array([[[
      [0] * 17 + [-1e9] * 7,
      [0] * 18 + [-1e9] * 6,
      [0] * 19 + [-1e9] * 5,
      [0] * 20 + [-1e9] * 4,
  ]]])
  self.assertAllClose(expected_val, res)
@parameterized.parameters(
    ((1, 64, 10), (80,), (80,)),
    ((1, 64, 64, 10), (8, 8), (16, 16)),
    ((1, 5, 64, 64, 10), (1, 8, 8), (1, 8, 8))
)
def testMultiheadAttentionND(self, tensor_shape, query_shape, memory_flange):
  """multihead_attention_nd keeps spatial dims and projects to output_depth."""
  attention_sizes = dict(
      total_key_depth=256,
      total_value_depth=256,
      output_depth=256,
      num_heads=4,
  )
  attended = common_attention.multihead_attention_nd(
      query_antecedent=tf.zeros(tensor_shape),
      memory_antecedent=None,
      query_shape=query_shape,
      memory_flange=memory_flange,
      masked=True,
      **attention_sizes)
  self.assertAllClose(self.evaluate(attended).shape,
                      tensor_shape[:-1] + (256,))
@parameterized.parameters(
    (15, (5,), (100,), (15,)),
    (10, (2, 2), (4, 4), (3, 0)),
    (25, (2, 2, 3), (10, 10, 12), (0, 0, 7))
)
def testDecodeStepToIndex(self, decode_step, query_shape, tensor_shape,
                          expected_index):
  """decode_step_to_index maps a flat decode step to a spatial index."""
  index = common_attention.decode_step_to_index(
      decode_step, query_shape, tensor_shape)
  self.assertAllClose(index, expected_index)
@test_utils.run_in_graph_and_eager_modes()
def testGetItemAtDecodeStep(self):
  """Decode step 100 with (2, 5, 5) query blocks reads the element 10."""
  # Values equal their flat index, so the expected item identifies the slot.
  tensor = tf.reshape(tf.range(25 * 25 * 4), [1, 4, 25, 25, 1])
  value = common_attention.get_item_at_decode_step(tensor, 100, (2, 5, 5))
  res = self.evaluate(value)
  expected_value = np.array([[[[[10]]]]])
  self.assertAllClose(expected_value, res)
@test_utils.run_in_graph_and_eager_modes()
def testPutItemAtDecodeStep(self):
  """Decode step 32 with (2, 2) query blocks writes to position (2, 6)."""
  tensor = tf.zeros([1, 1, 10, 10, 1])
  item = tf.ones([1, 1, 1, 1, 1])
  value = common_attention.put_item_in_decode_step(tensor, item, 32, (2, 2))
  res = self.evaluate(value)
  # Only the targeted position becomes 1; everything else stays 0.
  expected_val = np.zeros([1, 1, 10, 10, 1])
  expected_val[0, 0, 2, 6, 0] = 1
  self.assertAllClose(expected_val, res)
@parameterized.named_parameters(
    ("", 1, 1, 8, 4, 1, 2),
    ("dynamic_batch", None, 1, 8, 4, 1, 2),
    ("batches", 4, 3, 8, 4, 1, 2),
    ("depth_v", 1, 1, 8, 4, 3, 2),
    ("block_length", 1, 1, 8, 4, 1, 4),
)
def testMaskedWithinBlockLocalAttention1D(self, batch, heads, length,
                                          depth_k, depth_v, block_length):
  """Output is [batch, heads, length, depth_v], including a tensor batch."""
  if batch is None:
    # Use a tensor-valued (random) batch size.
    batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
  qk_shape = [batch, heads, length, depth_k]
  q = tf.random_normal(qk_shape)
  k = tf.random_normal(qk_shape)
  v = tf.random_normal([batch, heads, length, depth_v])
  output = common_attention.masked_within_block_local_attention_1d(
      q, k, v, block_length=block_length)
  batch_is_tensor = isinstance(batch, tf.Tensor)
  if batch_is_tensor:
    # Fetch the batch size together with the output so they agree.
    batch, res = self.evaluate([batch, output])
  else:
    res = self.evaluate(output)
  self.assertEqual(res.shape, (batch, heads, length, depth_v))
@parameterized.named_parameters(
    ("", 1, 1, 8, 4, 1, 2),
    ("dynamic_batch", None, 1, 8, 4, 1, 2),
    ("batches", 4, 3, 8, 4, 1, 2),
    ("depth_v", 1, 1, 8, 4, 3, 2),
    ("block_length", 1, 1, 8, 4, 1, 4),
)
def testMaskedLocalAttention1D(self, batch, heads, length, depth_k, depth_v,
                               block_length):
  """Output is [batch, heads, length, depth_v], including a tensor batch."""
  if batch is None:
    # Use a tensor-valued (random) batch size.
    batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
  qk_shape = [batch, heads, length, depth_k]
  q = tf.random_normal(qk_shape)
  k = tf.random_normal(qk_shape)
  v = tf.random_normal([batch, heads, length, depth_v])
  output = common_attention.masked_local_attention_1d(
      q, k, v, block_length=block_length)
  batch_is_tensor = isinstance(batch, tf.Tensor)
  if batch_is_tensor:
    # Fetch the batch size together with the output so they agree.
    batch, res = self.evaluate([batch, output])
  else:
    res = self.evaluate(output)
  self.assertEqual(res.shape, (batch, heads, length, depth_v))
@parameterized.named_parameters(
    ("", 1, 1, 8, 4, 4, (2, 2)),
    ("dynamic_batch", None, 1, 8, 4, 4, (2, 2)),
    ("batches", 3, 2, 8, 4, 4, (2, 2)),
    # TODO(trandustin): Extend function to enable depth_k != depth_v.
    # ("depth_v", 1, 1, 8, 4, 1, (2, 2)),
    ("query_shape", 1, 1, 8, 4, 4, (4, 4)),
)
def testMaskedLocalAttention2D(self, batch, heads, length, depth_k, depth_v,
                               query_shape):
  """Output is [batch, heads, length, length, depth_v] on a square image."""
  if batch is None:
    # Use a tensor-valued (random) batch size.
    batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
  qk_shape = [batch, heads, length, length, depth_k]
  q = tf.random_normal(qk_shape)
  k = tf.random_normal(qk_shape)
  v = tf.random_normal([batch, heads, length, length, depth_v])
  output = common_attention.masked_local_attention_2d(
      q, k, v, query_shape=query_shape, memory_flange=(2, 2))
  batch_is_tensor = isinstance(batch, tf.Tensor)
  if batch_is_tensor:
    # Fetch the batch size together with the output so they agree.
    batch, res = self.evaluate([batch, output])
  else:
    res = self.evaluate(output)
  self.assertEqual(res.shape, (batch, heads, length, length, depth_v))
@parameterized.named_parameters(
    ("matching_block_length", 3, 4, 25, 16, 16, 5),
    ("unmatching_block_length", 3, 4, 25, 16, 16, 4),
    ("dynamic_batch", None, 4, 25, 16, 16, 5),
    ("different_depth_v", 3, 4, 25, 16, 17, 5),
)
def testLocalUnmaskedAttention1D(self, batch, heads, length,
                                 depth_k, depth_v, block_length):
  """local_attention_1d output is [batch, heads, length, depth_v]."""
  if batch is None:
    # Use a tensor-valued (random) batch size.
    batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
  qk_shape = [batch, heads, length, depth_k]
  q = tf.random_normal(qk_shape)
  k = tf.random_normal(qk_shape)
  v = tf.random_normal([batch, heads, length, depth_v])
  output = common_attention.local_attention_1d(
      q, k, v, block_length=block_length, filter_width=3)
  batch_is_tensor = isinstance(batch, tf.Tensor)
  if batch_is_tensor:
    # Fetch the batch size together with the output so they agree.
    batch, res = self.evaluate([batch, output])
  else:
    res = self.evaluate(output)
  self.assertEqual(res.shape, (batch, heads, length, depth_v))
@parameterized.named_parameters(
    ("matching_block_length", 3, 4, 25, 16, 16, (4, 4)),
    ("unmatching_block_length", 3, 4, 25, 16, 16, (5, 5)),
    ("dynamic_batch", None, 4, 25, 16, 16, (4, 4)),
    # TODO(trandustin): Extend function to enable depth_k != depth_v.
    # ("different_depth_v", 3, 4, 25, 16, 17, (4, 4)),
)
def testLocalUnmaskedAttention2D(self, batch, heads, length,
                                 depth_k, depth_v, query_shape):
  """local_attention_2d output is [batch, heads, length, length, depth_v]."""
  if batch is None:
    # Use a tensor-valued (random) batch size.
    batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
  qk_shape = [batch, heads, length, length, depth_k]
  q = tf.random_normal(qk_shape)
  k = tf.random_normal(qk_shape)
  v = tf.random_normal([batch, heads, length, length, depth_v])
  output = common_attention.local_attention_2d(
      q, k, v, query_shape=query_shape, memory_flange=(3, 3))
  batch_is_tensor = isinstance(batch, tf.Tensor)
  if batch_is_tensor:
    # Fetch the batch size together with the output so they agree.
    batch, res = self.evaluate([batch, output])
  else:
    res = self.evaluate(output)
  self.assertEqual(res.shape, (batch, heads, length, length, depth_v))
@test_utils.run_in_graph_mode_only()
def testMultiheadSelfAttentionMemoryEfficient(self):
  """forget=True must reproduce forget=False outputs and gradients.

  Both variants are built from the same explicit weights (`test_vars`), so
  the forward output and every gradient must agree between them.
  """
  num_heads = 4
  io_size = 16
  batch = 2
  length = 7
  head_size = 5
  x_np = np.random.rand(batch, length, io_size)
  dy_np = np.random.rand(batch, length, io_size)
  with self.test_session() as session:
    x = tf.to_float(x_np)
    dy = tf.to_float(dy_np)
    bias = common_attention.attention_bias_lower_triangle(length)
    wqkv = tf.get_variable(
        "wqkv", [num_heads, 1, io_size, 3 * head_size],
        initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
    wo = tf.get_variable(
        "wo", [num_heads, 1, head_size, io_size],
        initializer=tf.random_normal_initializer(
            stddev=(head_size * num_heads)**-0.5))
    norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    test_vars = (wqkv, wo, norm_scale, norm_bias)
    # Differentiate w.r.t. the input and each explicit weight.
    wrt = [x, wqkv, wo, norm_scale, norm_bias]
    outputs = [
        common_attention.multihead_self_attention_memory_efficient(
            x, bias, num_heads, head_size=head_size, forget=forget,
            test_vars=test_vars)
        for forget in (False, True)]
    grads = [tf.gradients(ys=[out], xs=wrt, grad_ys=[dy])
             for out in outputs]
    session.run(tf.global_variables_initializer())
    (y, y_forget), (grads_keep, grads_forget) = session.run(
        (outputs, grads))
  self.assertAllClose(y, y_forget)
  for grad_keep, grad_forget in zip(grads_keep, grads_forget):
    self.assertAllClose(grad_keep, grad_forget)
@test_utils.run_in_graph_and_eager_modes()
def test2dGatherAndScatterInvertibility(self):
  """Scattering gathered 2d blocks back must reproduce the input exactly."""
  # (batch_size, num_heads, height, width, depth)
  full_shape = (2, 2, 4, 6, 8)
  query_shape = (2, 3)
  x = np.random.rand(*full_shape)
  indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  blocks = common_attention.gather_blocks_2d(x, indices)
  restored = common_attention.scatter_blocks_2d(
      blocks, indices, tf.constant(list(full_shape)))
  self.assertAllClose(x, self.evaluate(restored))
@test_utils.run_in_graph_and_eager_modes()
def test2dBlockRasterScanMask(self):
  """Checks make_2d_block_raster_mask against a hand-written mask."""
  q_shape = (2, 3)
  flange = (2, 1)
  res = self.evaluate(
      common_attention.make_2d_block_raster_mask(q_shape, flange))
  # Each expected row is left-padded with zeros up to `row_width`; only the
  # trailing pattern differs between rows.
  row_width = 20
  row_tails = [
      [1, 1, 1, 0, 1, 1, 1, 1],
      [1, 1, 0, 1, 1, 1, 1],
      [1, 0, 1, 1, 1, 1],
      [1, 0, 0, 1, 1, 1],
      [1, 0, 0, 0, 1, 1],
      [1, 0, 0, 0, 0, 1],
  ]
  correct_mask = np.array(
      [[0] * (row_width - len(tail)) + tail for tail in row_tails],
      dtype=np.float64)
  self.assertAllClose(correct_mask, res)
@test_utils.run_in_graph_and_eager_modes()
def test2dGather(self):
  """Checks gather_indices_2d and gather_blocks_2d on a 4x6 grid."""
  batch_size, num_heads = 2, 2
  height, width, depth = 4, 6, 8
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  flat = x.reshape(batch_size, num_heads, height * width, depth)
  # Flat (row-major) positions of the four 2x3 query blocks, raster order.
  correct_indices = [[0, 1, 2, 6, 7, 8],
                     [3, 4, 5, 9, 10, 11],
                     [12, 13, 14, 18, 19, 20],
                     [15, 16, 17, 21, 22, 23]]
  correct_gathered_x = [
      [[flat[b, h, block] for block in correct_indices]
       for h in range(num_heads)]
      for b in range(batch_size)]
  indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered = common_attention.gather_blocks_2d(x, indices)
  indices, gathered = self.evaluate([indices, gathered])
  self.assertAllEqual(correct_indices, indices)
  self.assertAllClose(correct_gathered_x, gathered)
@test_utils.run_in_graph_and_eager_modes()
def testGetMemoryRegion(self):
  """Checks get_memory_region's flange memory and center blocks."""
  np.set_printoptions(threshold=np.inf)
  batch_size, num_heads = 2, 2
  height, width, depth = 4, 6, 3
  query_shape = (2, 3)
  memory_flange = (1, 1)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  y = np.reshape(x, (batch_size, num_heads, -1, depth))

  def zero_rows(count):
    return np.zeros((count, depth), dtype=np.float32)

  def flange_for(cells):
    # Expected flange memory of the four query blocks for one
    # (batch, head) slice; `cells` is that slice flattened to
    # (height * width, depth).
    return [
        zero_rows(7),
        np.concatenate((zero_rows(5), cells[[2, 8]]), axis=0),
        np.concatenate((zero_rows(1), cells[[6, 7, 8, 9]], zero_rows(2)),
                       axis=0),
        np.concatenate((cells[[8, 9, 10, 11]], zero_rows(1),
                        cells[[14, 20]]), axis=0),
    ]

  # Flat positions of the four 2x3 query blocks in raster order.
  center_blocks = [[0, 1, 2, 6, 7, 8],
                   [3, 4, 5, 9, 10, 11],
                   [12, 13, 14, 18, 19, 20],
                   [15, 16, 17, 21, 22, 23]]
  correct_x_flange = np.array(
      [[flange_for(y[b, h]) for h in range(num_heads)]
       for b in range(batch_size)])
  correct_x_center = np.array(
      [[[y[b, h, block] for block in center_blocks]
        for h in range(num_heads)]
       for b in range(batch_size)])
  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  x_flange, x_center = common_attention.get_memory_region(
      tf.constant(x, dtype=tf.float32), query_shape, memory_flange, x_indices)
  x_flange, x_center = self.evaluate([x_flange, x_center])
  self.assertAllClose(correct_x_flange, x_flange)
  self.assertAllClose(correct_x_center, x_center)
@test_utils.run_in_graph_and_eager_modes()
def testGetShiftedCenterBlocks(self):
  """Testing get_shifted_center_blocks.

  Each expected block is a zero row followed by all but the last element
  of the corresponding raster-ordered 2x3 query block, i.e. the block
  shifted right by one position.
  """
  batch_size = 2
  num_heads = 2
  height = 4
  width = 6
  depth = 3
  query_shape = (2, 3)
  x = np.random.rand(batch_size, num_heads, height, width, depth)
  y = np.reshape(x, (batch_size, num_heads, -1, depth))
  zero_row = np.zeros((1, depth), dtype=np.float32)
  # Flat positions of the four 2x3 query blocks in raster order.
  center_blocks = [[0, 1, 2, 6, 7, 8],
                   [3, 4, 5, 9, 10, 11],
                   [12, 13, 14, 18, 19, 20],
                   [15, 16, 17, 21, 22, 23]]
  correct_gathered_x = np.array([
      [[np.concatenate((zero_row, y[b, h, block[:-1]]), axis=0)
        for block in center_blocks]
       for h in range(num_heads)]
      for b in range(batch_size)])
  x_indices = common_attention.gather_indices_2d(x, query_shape, query_shape)
  gathered_x = common_attention.get_shifted_center_blocks(
      tf.constant(x, dtype=tf.float32), x_indices)
  self.assertAllClose(correct_gathered_x, self.evaluate(gathered_x))
@test_utils.run_in_graph_and_eager_modes()
def testDotProductAttentionRelative(self):
  """Shape check for dot_product_attention_relative."""
  shape = (5, 7, 12, 32)
  queries = np.random.rand(*shape)
  memory = np.random.rand(*shape)
  to_tensor = lambda arr: tf.constant(arr, dtype=tf.float32)
  attention = common_attention.dot_product_attention_relative(
      to_tensor(queries), to_tensor(memory), to_tensor(memory), None,
      max_relative_position=3)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(attention).shape, shape)
@test_utils.run_in_graph_and_eager_modes()
def testRelativeAttentionV2(self):
  """Shape check for dot_product_self_attention_relative_v2 (per-head)."""
  shape = (5, 4, 16, 7)  # (batch, heads, length, depth)
  queries = np.random.rand(*shape)
  memory = np.random.rand(*shape)
  to_tensor = lambda arr: tf.constant(arr, dtype=tf.float32)
  attention = common_attention.dot_product_self_attention_relative_v2(
      to_tensor(queries), to_tensor(memory), to_tensor(memory), None,
      max_relative_position=3,
      heads_share_relative_embedding=False)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(attention).shape, shape)
@test_utils.run_in_graph_and_eager_modes()
def testRelativeAttentionV2SharedRel(self):
  """Shape check for dot_product_self_attention_relative_v2 (shared)."""
  shape = (5, 4, 16, 7)  # (batch, heads, length, depth)
  queries = np.random.rand(*shape)
  memory = np.random.rand(*shape)
  to_tensor = lambda arr: tf.constant(arr, dtype=tf.float32)
  attention = common_attention.dot_product_self_attention_relative_v2(
      to_tensor(queries), to_tensor(memory), to_tensor(memory), None,
      max_relative_position=3,
      heads_share_relative_embedding=True)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(attention).shape, shape)
@test_utils.run_in_graph_and_eager_modes()
def testRelativeAttentionV2MaxRelativeLargerThanLength(self):
  """max_relative_position exceeding the sequence length must still work."""
  shape = (5, 4, 3, 7)  # (batch, heads, length, depth)
  queries = np.random.rand(*shape)
  memory = np.random.rand(*shape)
  to_tensor = lambda arr: tf.constant(arr, dtype=tf.float32)
  attention = common_attention.dot_product_self_attention_relative_v2(
      to_tensor(queries), to_tensor(memory), to_tensor(memory), None,
      max_relative_position=16,
      heads_share_relative_embedding=False)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(attention).shape, shape)
@test_utils.run_in_graph_and_eager_modes()
def testDotProductUnMaskedAttentionRelativeV2(self):
  """Shape check for dot_product_unmasked_self_attention_relative_v2."""
  shape = (5, 7, 12, 32)
  queries = np.random.rand(*shape)
  memory = np.random.rand(*shape)
  to_tensor = lambda arr: tf.constant(arr, dtype=tf.float32)
  attention = common_attention.dot_product_unmasked_self_attention_relative_v2(
      to_tensor(queries), to_tensor(memory), to_tensor(memory), None, 35)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(attention).shape, shape)
@tfe.run_test_in_graph_and_eager_modes()
def testExtractblocks(self):
  """Checks _extract_blocks against a numpy reshape/transpose reference."""
  batch_size, num_heads = 1, 3
  height, width, depth = 6, 10, 15
  block_h, block_w = 3, 2
  merged = batch_size * num_heads
  t = np.random.rand(merged, height, width, depth)
  blocks = common_attention._extract_blocks(t, block_h, block_w)
  self.evaluate(tf.global_variables_initializer())
  res = self.evaluate(blocks)
  num_h, num_w = height // block_h, width // block_w
  self.assertEqual(res.shape, (merged, num_h, num_w, block_h, block_w, depth))
  # Split each spatial axis into (block index, offset within block), then
  # move both block indices ahead of the offsets.
  expected = t.reshape(merged, num_h, block_h, num_w, block_w, depth)
  expected = expected.transpose(0, 1, 3, 2, 4, 5)
  self.assertAllClose(res, expected)
def python_get_2d_local_memory(self, t, batch_size, num_heads, height, width,
                               num_h_blocks, num_w_blocks, query_shape,
                               memory_flange, depth):
  """Numpy reference that cuts overlapping 2d memory windows out of `t`.

  `t` is zero-padded by `memory_flange` on both sides of its two spatial
  axes (axes 2 and 3). Block (x, y) then receives the window of size
  (query_shape[0] + 2 * memory_flange[0], query_shape[1] + 2 * memory_flange[1])
  whose top-left corner is at (x * query_shape[0], y * query_shape[1]) in
  the padded array.

  Args:
    t: 5-d numpy array; axes 2 and 3 are the spatial axes that get padded.
    batch_size: size of output axis 0; only the first `batch_size` entries
      of t's axis 0 are read.
    num_heads: size of output axis 1; only the first `num_heads` entries of
      t's axis 1 are read.
    height: used (as height // query_shape[0]) to size output axis 2.
    width: used (as width // query_shape[1]) to size output axis 3.
    num_h_blocks: number of block rows actually filled.
    num_w_blocks: number of block columns actually filled.
    query_shape: (query height, query width).
    memory_flange: (flange height, flange width).
    depth: size of the last output axis.

  Returns:
    Array of shape (batch_size, num_heads, height // query_shape[0],
    width // query_shape[1], memory_height, memory_width, depth). Blocks
    outside num_h_blocks x num_w_blocks stay zero.
  """
  memory_height = query_shape[0] + 2 * memory_flange[0]
  memory_width = query_shape[1] + 2 * memory_flange[1]
  out = np.zeros((batch_size, num_heads, height // query_shape[0],
                  width // query_shape[1], memory_height, memory_width,
                  depth))
  # Pad only the spatial axes so that border windows see zeros.
  t_padded = np.pad(
      t, ((0, 0), (0, 0), (memory_flange[0], memory_flange[0]),
          (memory_flange[1], memory_flange[1]), (0, 0)),
      "constant", constant_values=0)
  for x in range(num_h_blocks):
    h_start = x * query_shape[0]
    for y in range(num_w_blocks):
      w_start = y * query_shape[1]
      # One slice copies the whole window for every batch entry and head.
      out[:, :, x, y] = t_padded[:batch_size, :num_heads,
                                 h_start:h_start + memory_height,
                                 w_start:w_start + memory_width]
  return out
@tfe.run_test_in_graph_and_eager_modes()
def testGet2dLocalMemory(self):
  """Compares get_2d_local_memory_v2 with the numpy reference."""
  batch_size = num_heads = 3
  height = width = 6
  depth = 15
  num_h_blocks = num_w_blocks = 3
  memory_flange = [1, 1]
  query_shape = [2, 2]
  t = np.random.rand(batch_size, num_heads, height, width, depth)
  merged_input_shape = (batch_size * num_heads, height, width, depth)
  memory = common_attention.get_2d_local_memory_v2(
      np.reshape(t, merged_input_shape), query_shape, memory_flange)
  self.evaluate(tf.global_variables_initializer())
  res = self.evaluate(memory)
  expected_shape = (batch_size * num_heads, num_h_blocks, num_w_blocks,
                    query_shape[0] + 2 * memory_flange[0],
                    query_shape[1] + 2 * memory_flange[1], depth)
  self.assertEqual(res.shape, expected_shape)
  reference = self.python_get_2d_local_memory(
      t, batch_size, num_heads, height, width, num_h_blocks, num_w_blocks,
      query_shape, memory_flange, depth)
  self.assertAllClose(res, np.reshape(reference, expected_shape))
@tfe.run_test_in_graph_and_eager_modes()
def testSplitAlongWidth(self):
  """Checks _split_along_width against strided numpy slices."""
  batch_size = 1
  num_heads = 3
  num_outer_h_blocks = 4
  num_outer_w_blocks = 8
  memory_flange = [2, 2]
  num_w_blocks = 3
  depth = 15
  t = np.random.rand(batch_size * num_heads, num_outer_h_blocks,
                     num_outer_w_blocks, memory_flange[0], memory_flange[1],
                     depth)
  res_l, res_r = self.evaluate(common_attention._split_along_width(t))
  expected_shape = (batch_size * num_heads, num_outer_h_blocks, num_w_blocks,
                    memory_flange[0], memory_flange[1], depth)
  self.assertEqual(res_l.shape, expected_shape)
  self.assertEqual(res_r.shape, expected_shape)
  # Left blocks are outer columns 0, 2, 4, ...; right blocks are 3, 5, 7, ...
  span = 2 * num_w_blocks
  self.assertAllClose(res_l, t[:, :, 0:span:2])
  self.assertAllClose(res_r, t[:, :, 3:3 + span:2])
@tfe.run_test_in_graph_and_eager_modes()
def testGetLeftRightBlocks(self):
  """Checks _get_left_right_blocks against an explicit numpy reference.

  Each output block stacks two vertically adjacent outer blocks (so its
  height is 2 * memory_flange[0]). Left blocks come from outer column 2*y
  and right blocks from outer column 2*y + 3, starting at outer row
  1 + 2*x.
  """
  batch_size = 1
  num_heads = 3
  num_outer_h_blocks = 6
  num_outer_w_blocks = 6
  memory_flange = [2, 2]
  num_h_blocks = 2
  num_w_blocks = 2
  depth = 15
  t = np.random.rand(batch_size*num_heads, num_outer_h_blocks,
                     num_outer_w_blocks, memory_flange[0], memory_flange[1],
                     depth)
  a = common_attention._get_left_right_blocks(t)
  self.evaluate(tf.global_variables_initializer())
  res_l, res_r = self.evaluate(a)
  flange_h = memory_flange[0]
  block_h = 2 * flange_h
  expected_shape = (batch_size * num_heads, num_h_blocks, num_w_blocks,
                    block_h, memory_flange[1], depth)
  self.assertEqual(res_l.shape, expected_shape)
  self.assertEqual(res_r.shape, expected_shape)
  # Also check the content.
  out_l = np.zeros(expected_shape)
  out_r = np.zeros(expected_shape)
  for b in range(batch_size * num_heads):
    for x in range(num_h_blocks):
      for y in range(num_w_blocks):
        for v in range(block_h):
          # Row v of a stacked block maps to outer block row
          # 1 + 2*x + v // flange_h and to row v % flange_h inside it.
          outer_h = 1 + 2 * x + v // flange_h
          inner_h = v % flange_h
          out_l[b, x, y, v] = t[b, outer_h, 2 * y, inner_h]
          out_r[b, x, y, v] = t[b, outer_h, 2 * y + 3, inner_h]
  self.assertAllClose(res_l, out_l)
  self.assertAllClose(res_r, out_r)
@tfe.run_test_in_graph_and_eager_modes()
def testDotProductUnmaskedAttentionLocal2dTpu(self):
  """Compares dot_product_unmasked_attention_local_2d_tpu with numpy."""
  batch_size, num_heads = 1, 3
  height, width, depth = 7, 12, 15
  num_h_blocks, num_w_blocks = 4, 6
  memory_flange = [1, 1]
  query_shape = [2, 2]
  memory_h = query_shape[0] + 2 * memory_flange[0]
  memory_w = query_shape[1] + 2 * memory_flange[1]
  shape = (batch_size, num_heads, height, width, depth)
  q = np.random.rand(*shape)
  k = np.random.rand(*shape)
  v = np.random.rand(*shape)
  a = common_attention.dot_product_unmasked_attention_local_2d_tpu(
      tf.constant(q, dtype=tf.float32),
      tf.constant(k, dtype=tf.float32),
      tf.constant(v, dtype=tf.float32), None, max_relative_position=None,
      query_shape=query_shape, dropout_rate=0.0, image_shapes=None,
      name=None, make_image_summary=False, dropout_broadcast_dims=None)
  self.evaluate(tf.global_variables_initializer())
  res = self.evaluate(a)
  self.assertEqual(res.shape, shape)

  # Numpy reference: first zero-pad q, k and v at the bottom/right so that
  # both spatial dims are multiples of query_shape.
  height_padding = -height % query_shape[0]
  width_padding = -width % query_shape[1]
  new_height = height + height_padding
  new_width = width + width_padding
  pad_spec = ((0, 0), (0, 0), (0, height_padding), (0, width_padding),
              (0, 0))
  q, k, v = [np.pad(arr, pad_spec, "constant", constant_values=0)
             for arr in (q, k, v)]
  # Queries use no flange; keys and values use the full memory flange.
  queries = self.python_get_2d_local_memory(
      q, batch_size, num_heads, new_height, new_width, num_h_blocks,
      num_w_blocks, query_shape, [0, 0], depth)
  keys = self.python_get_2d_local_memory(
      k, batch_size, num_heads, new_height, new_width, num_h_blocks,
      num_w_blocks, query_shape, memory_flange, depth)
  values = self.python_get_2d_local_memory(
      v, batch_size, num_heads, new_height, new_width, num_h_blocks,
      num_w_blocks, query_shape, memory_flange, depth)
  block_grid = (batch_size, num_heads, num_h_blocks, num_w_blocks)
  queries = np.reshape(
      queries, block_grid + (query_shape[0] * query_shape[1], depth))
  keys = np.reshape(keys, block_grid + (memory_h * memory_w, depth))
  values = np.reshape(values, block_grid + (memory_h * memory_w, depth))
  logits = np.matmul(queries, np.swapaxes(keys, -1, -2))
  weights = np.exp(logits)
  weights /= np.sum(weights, axis=-1, keepdims=True)
  att_output = np.reshape(
      np.matmul(weights, values),
      block_grid + (query_shape[0], query_shape[1], depth))
  # Interleave block indices with in-block offsets to rebuild the grid.
  out = np.transpose(att_output, (0, 1, 2, 4, 3, 5, 6))
  out = np.reshape(out, (batch_size, num_heads, new_height, new_width, depth))
  self.assertAllClose(res, out[:, :, :height, :width, :])
@tfe.run_test_in_graph_and_eager_modes()
def testDotProductUnmaskedAttentionLocal2dTpuSimple(self):
  """Compares the simple local 2d TPU attention with a numpy reference.

  The op returns its output together with the q, k and v it computed; the
  reference recomputes block-local attention from those q, k and v.
  """
  batch_size = 1
  num_heads = 3
  height = 8
  width = 12
  total_depth = 15
  num_h_blocks = 4
  num_w_blocks = 6
  depth = 5
  query_shape = [2, 2]
  x = np.random.rand(batch_size, height, width, total_depth)
  a = (
      common_attention.dot_product_unmasked_attention_local_2d_tpu_simple(
          tf.constant(x, dtype=tf.float32),
          None, total_depth, total_depth, num_heads,
          query_shape=query_shape))
  self.evaluate(tf.global_variables_initializer())
  res, q, k, v = self.evaluate(a)
  self.assertEqual(res.shape, (batch_size, height, width, total_depth))
  block_grid = (batch_size, num_heads, num_h_blocks, num_w_blocks)

  def to_blocks(arr):
    # Regroup the returned tensor into per-head query blocks, flattened to
    # (batch, heads, h_blocks, w_blocks, block_size, depth).
    arr = np.reshape(arr, (batch_size, num_h_blocks, num_w_blocks, num_heads,
                           query_shape[0], query_shape[1], depth))
    arr = np.transpose(arr, (0, 3, 1, 2, 4, 5, 6))
    return np.reshape(arr,
                      block_grid + (query_shape[0] * query_shape[1], depth))

  queries, keys, values = map(to_blocks, (q, k, v))
  logits = np.matmul(queries, np.swapaxes(keys, -1, -2))
  weights = np.exp(logits)
  weights /= np.sum(weights, axis=-1, keepdims=True)
  att_output = np.reshape(
      np.matmul(weights, values),
      block_grid + (query_shape[0], query_shape[1], depth))
  new_height = height + -height % query_shape[0]
  new_width = width + -width % query_shape[1]
  # Interleave block indices with in-block offsets to rebuild the grid.
  out = np.transpose(att_output, (0, 1, 2, 4, 3, 5, 6))
  out = np.reshape(out, (batch_size, num_heads, new_height, new_width, depth))
  # Move heads next to depth and merge them into the channel axis.
  out = np.transpose(out, (0, 2, 3, 1, 4))
  out = np.reshape(out, (batch_size, new_height, new_width, total_depth))
  self.assertAllClose(res, out[:, :height, :width, :])
def python_relative_att(self, q, k, v, batch, num_heads, height, width,
                        depth, height_key_relative_embeddings,
                        width_key_relative_embeddings,
                        heads_share_relative_embedding):
  """Relative attention computation in numpy.

  For a query at (row i, col j) and a key at (row l, col m), the logit is
    q_{ij} . k_{lm} + q_{ij} . rw_{m-j} + q_{ij} . rh_{l-i},
  where rh and rw are the relative embeddings along the height and width
  spatial dimensions. A relative distance d indexes embedding row
  (size - 1 + d), with size equal to height or width respectively.

  Args:
    q: [batch, heads, height, width, depth] array
    k: [batch, heads, height, width, depth] array
    v: [batch, heads, height, width, depth] array
    batch: int scalar
    num_heads: int scalar
    height: int scalar
    width: int scalar
    depth: int scalar
    height_key_relative_embeddings: [2*height-1, depth] array when shared,
      otherwise indexed first by head.
    width_key_relative_embeddings: [2*width-1, depth] array when shared,
      otherwise indexed first by head.
    heads_share_relative_embedding: a boolean; whether all heads use the
      same relative embeddings.

  Returns:
    att_output: [batch, heads, height, width, depth] array.
  """
  length = height * width
  logits = np.zeros((batch, num_heads, length, length))
  for b in range(batch):
    for h in range(num_heads):
      if heads_share_relative_embedding:
        rel_h = height_key_relative_embeddings
        rel_w = width_key_relative_embeddings
      else:
        rel_h = height_key_relative_embeddings[h]
        rel_w = width_key_relative_embeddings[h]
      for i in range(length):
        q_row, q_col = divmod(i, width)
        query = q[b][h][q_row][q_col]
        for j in range(length):
          k_row, k_col = divmod(j, width)
          logit = np.dot(query, k[b][h][k_row][k_col])
          width_rel_logit = np.dot(query, rel_w[width - 1 + k_col - q_col])
          height_rel_logit = np.dot(query,
                                    rel_h[height - 1 + k_row - q_row])
          logits[b, h, i, j] = logit + width_rel_logit + height_rel_logit
  # Softmax over keys. Subtracting the row max leaves the result unchanged
  # mathematically but prevents exp() from overflowing to inf/NaN.
  logits -= np.max(logits, axis=-1, keepdims=True)
  att = np.exp(logits)
  att /= np.sum(att, axis=-1, keepdims=True)
  att_output = np.matmul(att, np.reshape(v, (batch, num_heads, length, depth)))
  return np.reshape(att_output, (batch, num_heads, height, width, depth))
@test_utils.run_in_graph_and_eager_modes()
def testDotProductUnMaskedAttentionRelative2d(self):
  """Compares per-head 2d relative attention with the numpy reference."""
  batch, height, width = 1, 3, 3
  num_heads = 2
  max_relative_position = 6
  depth = 5
  share = False
  dims = (batch, num_heads, height, width, depth)
  q, k, v = [np.random.rand(*dims) for _ in range(3)]
  to_tensor = lambda arr: tf.constant(arr, dtype=tf.float32)
  attention = common_attention.dot_product_unmasked_self_attention_relative_2d(
      to_tensor(q), to_tensor(k), to_tensor(v), None,
      max_relative_position=max_relative_position,
      heads_share_relative_embedding=share)
  self.evaluate(tf.global_variables_initializer())
  res, height_emb, width_emb = self.evaluate(attention)
  expected = self.python_relative_att(
      q, k, v, batch, num_heads, height, width, depth,
      height_emb, width_emb, share)
  self.assertEqual(res.shape, dims)
  self.assertAllClose(res, expected)
@parameterized.parameters(
    (1, 10, 12, 2, 6, 3),
    (1, 1, 12, 2, 6, 3),
    (2, 10, 1, 2, 6, 3),
    (1, 10, 12, 2, 1, 1),
    (1, 10, 12, 2, 2, 8),
    (4, 10, 12, 2, 12, 10),
)
@test_utils.run_in_graph_and_eager_modes()
def testDotProductUnMaskedAttentionRelative2dSharedOneRow(
    self, batch, height, width, num_heads, max_relative_position, depth):
  """Compares shared-embedding 2d relative attention with numpy."""
  dims = (batch, num_heads, height, width, depth)
  q, k, v = [np.random.rand(*dims) for _ in range(3)]
  to_tensor = lambda arr: tf.constant(arr, dtype=tf.float32)
  attention = common_attention.dot_product_unmasked_self_attention_relative_2d(
      to_tensor(q), to_tensor(k), to_tensor(v), None,
      max_relative_position=max_relative_position,
      heads_share_relative_embedding=True)
  self.evaluate(tf.global_variables_initializer())
  res, height_emb, width_emb = self.evaluate(attention)
  expected = self.python_relative_att(
      q, k, v, batch, num_heads, height, width, depth,
      height_emb, width_emb, True)
  self.assertEqual(res.shape, dims)
  self.assertAllClose(res, expected)
@test_utils.run_in_graph_and_eager_modes()
def testRelativeAttentionV2Unmasked(self):
  """Shape check for unmasked relative self-attention v2 (per-head)."""
  dims = (5, 4, 16, 7)  # (batch, heads, length, depth)
  query_np = np.random.rand(*dims)
  memory_np = np.random.rand(*dims)
  as_float = lambda arr: tf.constant(arr, dtype=tf.float32)
  out = common_attention.dot_product_unmasked_self_attention_relative_v2(
      as_float(query_np), as_float(memory_np), as_float(memory_np), None,
      max_relative_position=3,
      heads_share_relative_embedding=False)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(out).shape, dims)
@test_utils.run_in_graph_and_eager_modes()
def testRelativeAttentionV2UnmaskedSharedRel(self):
  """Shape check for unmasked relative self-attention v2 (shared)."""
  dims = (5, 4, 16, 7)  # (batch, heads, length, depth)
  query_np = np.random.rand(*dims)
  memory_np = np.random.rand(*dims)
  as_float = lambda arr: tf.constant(arr, dtype=tf.float32)
  out = common_attention.dot_product_unmasked_self_attention_relative_v2(
      as_float(query_np), as_float(memory_np), as_float(memory_np), None,
      max_relative_position=3,
      heads_share_relative_embedding=True)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(out).shape, dims)
@test_utils.run_in_graph_and_eager_modes()
def testRelativeAttentionV2UnmaskedRelativeLargerThanLength(self):
  """Unmasked v2 with max_relative_position exceeding the length."""
  dims = (5, 4, 3, 7)  # (batch, heads, length, depth)
  query_np = np.random.rand(*dims)
  memory_np = np.random.rand(*dims)
  as_float = lambda arr: tf.constant(arr, dtype=tf.float32)
  out = common_attention.dot_product_unmasked_self_attention_relative_v2(
      as_float(query_np), as_float(memory_np), as_float(memory_np), None,
      max_relative_position=16,
      heads_share_relative_embedding=False)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(out).shape, dims)
@test_utils.run_in_graph_and_eager_modes()
def testMaskedRelativeLocalAttentionV2(self):
  """Shape check for masked_relative_local_attention_1d."""
  dims = (5, 4, 16, 7)  # (batch, heads, length, depth)
  query_np = np.random.rand(*dims)
  memory_np = np.random.rand(*dims)
  as_float = lambda arr: tf.constant(arr, dtype=tf.float32)
  out = common_attention.masked_relative_local_attention_1d(
      as_float(query_np), as_float(memory_np), as_float(memory_np),
      block_length=3,
      heads_share_relative_embedding=True,
      add_relative_to_values=False,
      name="masked_relative_local_attention_1d")
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(out).shape, dims)
@test_utils.run_in_graph_and_eager_modes()
def testMaskedRelativeLocalAttentionV2AddRelativeValues(self):
  """Shape check for masked_relative_local_attention_1d with relative values.

  Unlike testMaskedRelativeLocalAttentionV2, this enables
  add_relative_to_values so that code path is actually exercised.
  """
  # (batch, heads, length, depth)
  x = np.random.rand(5, 4, 16, 7)
  y = np.random.rand(5, 4, 16, 7)
  block_length = 3
  a = common_attention.masked_relative_local_attention_1d(
      tf.constant(x, dtype=tf.float32),
      tf.constant(y, dtype=tf.float32),
      tf.constant(y, dtype=tf.float32),
      block_length=block_length,
      heads_share_relative_embedding=True,
      add_relative_to_values=True,
      name="masked_relative_local_attention_1d")
  self.evaluate(tf.global_variables_initializer())
  res = self.evaluate(a)
  self.assertEqual(res.shape, (5, 4, 16, 7))
@test_utils.run_in_graph_and_eager_modes()
def testMaskedRelativeLocalAttentionV2SeqShorterThanBlockLength(self):
  """Sequence length 2 is shorter than block_length 3."""
  dims = (5, 7, 2, 7)  # (batch, heads, length, depth)
  query_np = np.random.rand(*dims)
  memory_np = np.random.rand(*dims)
  as_float = lambda arr: tf.constant(arr, dtype=tf.float32)
  out = common_attention.masked_relative_local_attention_1d(
      as_float(query_np), as_float(memory_np), as_float(memory_np),
      block_length=3,
      heads_share_relative_embedding=True,
      name="masked_relative_local_attention_1d")
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(out).shape, dims)
@test_utils.run_in_graph_and_eager_modes()
def testMaskedRelativeLocalAttentionV2SeqShorterThanTwiceBlockLength(self):
  """Length 5 lies between block_length (3) and twice it; shape is kept."""
  # Layout is (batch, heads, length, depth).
  dims = (5, 7, 5, 7)
  query_np = np.random.rand(*dims)
  memory_np = np.random.rand(*dims)
  attended = common_attention.masked_relative_local_attention_1d(
      tf.constant(query_np, dtype=tf.float32),
      tf.constant(memory_np, dtype=tf.float32),
      tf.constant(memory_np, dtype=tf.float32),
      block_length=3,
      heads_share_relative_embedding=True,
      name="masked_relative_local_attention_1d")
  self.evaluate(tf.global_variables_initializer())
  result = self.evaluate(attended)
  self.assertEqual(result.shape, dims)
@test_utils.run_in_graph_and_eager_modes()
def testBiasBatchCoordinates(self):
  """Testing the batch coordinates mask.

  A query may only attend to keys carrying the same batch coordinate;
  every other (query, key) pair gets a -1e9 bias. Decorated like the
  sibling testBiasFuture so it runs in both graph and eager modes.
  """
  q = tf.constant([0, 0, 1, 1, 1, 1, 2, 2, 2], dtype=tf.int32)
  q = tf.expand_dims(q, axis=-1)
  k = tf.constant([0, 0, 0, 2, 2, 3, 3, 3], dtype=tf.int32)
  k = tf.expand_dims(k, axis=-1)
  ground_truth = np.array([
      [0, 0, 0, 1, 1, 1, 1, 1],  # 0
      [0, 0, 0, 1, 1, 1, 1, 1],
      [1, 1, 1, 1, 1, 1, 1, 1],  # 1 (just masked)
      [1, 1, 1, 1, 1, 1, 1, 1],
      [1, 1, 1, 1, 1, 1, 1, 1],
      [1, 1, 1, 1, 1, 1, 1, 1],
      [1, 1, 1, 0, 0, 1, 1, 1],  # 2
      [1, 1, 1, 0, 0, 1, 1, 1],
      [1, 1, 1, 0, 0, 1, 1, 1],
  ], np.float32) * -1e9
  bias = common_attention.attention_bias_coordinates(q, k)
  self.assertAllClose(self.evaluate(bias), ground_truth)
@test_utils.run_in_graph_and_eager_modes()
def testBiasFuture(self):
  """Keys whose position is later than the query position get -1e9 bias."""
  query_positions = tf.expand_dims(
      tf.constant([0, 1, 2, 3, 0, 1, 2, 0, 1], dtype=tf.int32), axis=-1)
  key_positions = tf.expand_dims(
      tf.constant([0, 1, 2, 3, 4, 0, 1, 2], dtype=tf.int32), axis=-1)
  # Row i marks (with 1) every key position strictly after query i's position.
  masked = np.array([
      [0, 1, 1, 1, 1, 0, 1, 1],  # 0
      [0, 0, 1, 1, 1, 0, 0, 1],  # 1
      [0, 0, 0, 1, 1, 0, 0, 0],  # 2
      [0, 0, 0, 0, 1, 0, 0, 0],  # 3
      [0, 1, 1, 1, 1, 0, 1, 1],  # 0
      [0, 0, 1, 1, 1, 0, 0, 1],  # 1
      [0, 0, 0, 1, 1, 0, 0, 0],  # 2
      [0, 1, 1, 1, 1, 0, 1, 1],  # 0
      [0, 0, 1, 1, 1, 0, 0, 1],  # 1
  ], np.float32)
  expected = masked * -1e9
  actual = self.evaluate(
      common_attention.attention_bias_future(query_positions, key_positions))
  self.assertAllClose(actual, expected)
@test_utils.run_in_graph_mode_only()
def testMultiheadAttentionWithLayerCollection(self):
  """Multihead attention registers its layers with a kfac LayerCollection."""
  collection = kfac.LayerCollection()
  inputs = tf.zeros([3, 4, 5], tf.float32)
  common_attention.multihead_attention(
      inputs,
      None,  # memory_antecedent
      None,  # bias
      10,  # total_key_depth
      10,  # total_value_depth
      10,  # output_depth
      2,  # num_heads
      0.2,  # dropout_rate
      layer_collection=collection)
  registered = collection.get_blocks()
  self.assertLen(registered, 4)
@parameterized.named_parameters(
    ("", 1, 1, 8, 4, 3),
    ("dynamic_batch", None, 1, 8, 4, 2),
    ("batches", 4, 3, 8, 4, 2),
    ("block_length", 1, 1, 8, 4, 4),
)
def testDilatedAttention(self, batch, heads, length, depth_v, block_length):
  """Dilated self attention keeps the (batch, heads, length, depth) shape."""
  dynamic = batch is None
  if dynamic:
    # Batch size only known at run time.
    batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
  shape = [batch, heads, length, depth_v]
  queries = tf.random_normal(shape)
  keys = tf.random_normal(shape)
  values = tf.random_normal(shape)
  attended = common_attention.dilated_self_attention_1d(
      queries, keys, values,
      query_block_size=block_length,
      memory_block_size=block_length,
      gap_size=2,
      num_memory_blocks=2)
  if not isinstance(batch, tf.Tensor):
    result = self.evaluate(attended)
  else:
    # Fetch the batch size and output together so they come from one run.
    batch, result = self.evaluate([batch, attended])
  self.assertEqual(result.shape, (batch, heads, length, depth_v))
@parameterized.named_parameters(
    ("", 1, 1, 8, 4, 3),
    ("dynamic_batch", None, 1, 8, 4, 2),
    ("batches", 4, 3, 8, 4, 2),
    ("block_length", 1, 1, 8, 4, 4),
)
def testMaskedDilatedAttention(self, batch, heads, length, depth_v,
                               block_length):
  """Masked dilated self attention keeps the input shape."""
  if batch is None:
    # Batch size only known at run time.
    batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
  shape = [batch, heads, length, depth_v]
  queries = tf.random_normal(shape)
  keys = tf.random_normal(shape)
  values = tf.random_normal(shape)
  attended = common_attention.masked_dilated_self_attention_1d(
      queries, keys, values,
      query_block_size=block_length,
      memory_block_size=block_length,
      gap_size=2,
      num_memory_blocks=2)
  if not isinstance(batch, tf.Tensor):
    result = self.evaluate(attended)
  else:
    # Fetch the batch size and output together so they come from one run.
    batch, result = self.evaluate([batch, attended])
  self.assertEqual(result.shape, (batch, heads, length, depth_v))
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when run as a script.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/common_audio.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for audio."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import scipy.signal
import tensorflow.compat.v1 as tf
def add_delta_deltas(filterbanks, name=None):
  """Append first- and second-order time-derivative channels.

  The input channel is kept as channel 0, its delta as channel 1 and its
  delta-delta as channel 2. All three come from a single 9-tap convolution
  along the time axis, with each kernel normalized to unit L2 norm.

  Args:
    filterbanks: float32 tensor with shape [batch_size, len, num_bins, 1]
    name: scope name
  Returns:
    float32 tensor with shape [batch_size, len, num_bins, 3]
  """
  delta = np.array([2, 1, 0, -1, -2])
  # Convolving the 5-tap delta kernel with itself gives a 9-tap kernel.
  delta_delta = scipy.signal.convolve(delta, delta, "full")
  identity_kernel = [0] * 4 + [1] + [0] * 4
  delta_kernel = [0] * 2 + list(delta) + [0] * 2
  kernels = np.array(
      [identity_kernel, delta_kernel, list(delta_delta)], dtype=np.float32)
  # [3, 9] -> [9, 1, 1, 3]: (height=time, width, in_channels, out_channels).
  kernels = kernels.T[:, None, None, :]
  norms = np.sqrt(np.sum(kernels**2, axis=0, keepdims=True))
  kernels /= norms
  return tf.nn.conv2d(
      filterbanks, kernels, [1, 1, 1, 1], "SAME", data_format="NHWC",
      name=name)
def compute_mel_filterbank_features(
    waveforms,
    sample_rate=16000, dither=1.0 / np.iinfo(np.int16).max, preemphasis=0.97,
    frame_length=25, frame_step=10, fft_length=None,
    window_fn=functools.partial(tf.signal.hann_window, periodic=True),
    lower_edge_hertz=80.0, upper_edge_hertz=7600.0, num_mel_bins=80,
    log_noise_floor=1e-3, apply_mask=True):
  """Implement mel-filterbank extraction using tf ops.

  Args:
    waveforms: float32 tensor with shape [batch_size, max_len]
    sample_rate: sampling rate of the waveform
    dither: stddev of Gaussian noise added to waveform to prevent quantization
      artefacts
    preemphasis: waveform high-pass filtering constant
    frame_length: frame length in ms
    frame_step: frame step in ms
    fft_length: number of fft bins; if None, the smallest power of two that is
      >= the frame length in samples
    window_fn: windowing function
    lower_edge_hertz: lowest frequency of the filterbank
    upper_edge_hertz: highest frequency of the filterbank
    num_mel_bins: filterbank size
    log_noise_floor: lower clip applied before the log, so that zero (or tiny)
      energies do not produce -inf / huge negative values
    apply_mask: When working on a batch of samples, set padding frames to zero
  Returns:
    filterbanks: a float32 tensor with shape [batch_size, len, num_bins, 1]
  """
  # `stfts` is a complex64 Tensor representing the short-time Fourier
  # Transform of each signal in `signals`. Its shape is
  # [batch_size, ?, fft_unique_bins]
  # where fft_unique_bins = fft_length // 2 + 1

  # Find the wave length: the largest index for which the value is !=0
  # note that waveforms samples that are exactly 0.0 are quite common, so
  # simply doing sum(waveforms != 0, axis=-1) will not work correctly.
  wav_lens = tf.reduce_max(
      tf.expand_dims(tf.range(tf.shape(waveforms)[1]), 0) *
      tf.to_int32(tf.not_equal(waveforms, 0.0)),
      axis=-1) + 1
  if dither > 0:
    waveforms += tf.random_normal(tf.shape(waveforms), stddev=dither)
  if preemphasis > 0:
    # First-order high-pass filter; the result is one sample shorter.
    waveforms = waveforms[:, 1:] - preemphasis * waveforms[:, :-1]
    wav_lens -= 1
  # Convert frame sizes from milliseconds to samples.
  frame_length = int(frame_length * sample_rate / 1e3)
  frame_step = int(frame_step * sample_rate / 1e3)
  if fft_length is None:
    fft_length = int(2**(np.ceil(np.log2(frame_length))))

  stfts = tf.signal.stft(
      waveforms,
      frame_length=frame_length,
      frame_step=frame_step,
      fft_length=fft_length,
      window_fn=window_fn,
      pad_end=True)

  # Number of frames per example: ceil(wav_lens / frame_step).
  stft_lens = (wav_lens + (frame_step - 1)) // frame_step
  # NOTE(review): `less_equal` also keeps frame index `stft_lens`, one past
  # the last frame that starts inside the signal. Confirm this is intended
  # before tightening it to `less`; doing so changes features for trained
  # models.
  masks = tf.to_float(tf.less_equal(
      tf.expand_dims(tf.range(tf.shape(stfts)[1]), 0),
      tf.expand_dims(stft_lens, 1)))

  # An energy spectrogram is the magnitude of the complex-valued STFT.
  # A float32 Tensor of shape [batch_size, ?, fft_unique_bins]
  # (257 with the default 25 ms frames at 16 kHz).
  magnitude_spectrograms = tf.abs(stfts)

  # Warp the linear-scale, magnitude spectrograms into the mel-scale.
  # `dimension_value` works whether shape dims are V1 `Dimension` objects or
  # V2 plain ints; `.value` only works for the former.
  num_spectrogram_bins = tf.dimension_value(magnitude_spectrograms.shape[-1])
  linear_to_mel_weight_matrix = (
      tf.signal.linear_to_mel_weight_matrix(
          num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
          upper_edge_hertz))
  mel_spectrograms = tf.tensordot(
      magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
  # Note: Shape inference for tensordot does not currently handle this case.
  mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
      linear_to_mel_weight_matrix.shape[-1:]))

  log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms))

  if apply_mask:
    # `masks` is already float32; broadcast it over the mel-bin axis.
    log_mel_sgram *= tf.expand_dims(masks, -1)

  return tf.expand_dims(log_mel_sgram, -1, name="mel_sgrams")
================================================
FILE: tensor2tensor/layers/common_hparams.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters and ranges common to multiple models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import zip # pylint: disable=redefined-builtin
from tensor2tensor.utils import hparam
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
@registry.register_hparams("basic_1")
def basic_params1():
  """A set of basic hyperparameters, registered under the name "basic_1".

  Returns:
    An `hparam.HParams` holding the default value of every hyperparameter
    listed below.
  """
  return hparam.HParams(
      # If the problem consists of variable-length sequences
      # (see problem.batch_size_means_tokens()), then this is the number
      # of tokens per batch per GPU or per TPU core. Otherwise, this is
      # the number of examples per GPU or per TPU core.
      batch_size=4096,
      batch_shuffle_size=512,
      # If True, then if the features are of variable length, the batch_size is
      # used as the actual batch size (and not tokens per batch).
      use_fixed_batch_size=False,
      num_hidden_layers=4,
      kernel_height=3,
      kernel_width=1,
      hidden_size=64,
      compress_steps=0,
      # All hyperparameters ending in "dropout" are automatically set to 0.0
      # when not in training mode.
      dropout=0.2,
      clip_grad_norm=2.0,
      grad_noise_scale=0.0,
      summarize_grads=False,
      # Flag for whether mlperf mode is on
      mlperf_mode=False,
      # Whether to log the name and size of every variable
      summarize_vars=False,
      initializer="orthogonal",
      initializer_gain=1.5,
      label_smoothing=0.1,
      optimizer="adam",
      optimizer_adam_epsilon=1e-6,
      optimizer_adam_beta1=0.85,
      optimizer_adam_beta2=0.997,
      optimizer_momentum_momentum=0.9,
      optimizer_momentum_nesterov=False,
      optimizer_adafactor_beta1=0.0,
      optimizer_adafactor_beta2=0.999,
      optimizer_adafactor_factored=True,
      optimizer_adafactor_decay_type="pow",
      optimizer_adafactor_memory_exponent=0.8,
      optimizer_adafactor_clipping_threshold=1.0,
      optimizer_adafactor_multiply_by_parameter_scale=True,
      # Number of accumulating steps for multi step optimizers.
      optimizer_multistep_accumulate_steps=0,
      # Loss scaling used.
      # Generally only necessary with mixed precision training.
      # Mixed precision training currently only supports exponential scaling.
      # To disable the scaler, set to 0/False.
      mixed_precision_optimizer_loss_scaler="exponential",
      # Determines the initial loss scaling value for mixed precision
      mixed_precision_optimizer_init_loss_scale=2**15,
      # Whether to zero gradients that were not computed, so that the
      # appropriate slots are created. Useful for sharing checkpoints between
      # models with different sets of heads.
      optimizer_zero_grads=False,
      weight_decay=1e-6,
      weight_noise=0.0,
      # Defines the learning rate as a product of named functions.
      # Available functions are listed in learning_rate._LEARNING_RATE_FUNCTIONS
      # e.g. "constant*linear_warmup*rsqrt_decay*rsqrt_hidden_size"
      learning_rate_schedule="legacy",
      learning_rate_constant=1.0,
      # If learning_rate_schedule=="legacy",
      # then we specify decay scheme here. Warmup is always exponential,
      # except with "noam" learning rate decay scheme.
      # see optimize.legacy_learning_rate_schedule()
      # TODO(noam): migrate everyone away from this.
      learning_rate_decay_scheme="none",
      # decay_steps and decay_staircase for learning_rate_decay_scheme=="exp"
      learning_rate_decay_steps=5000,
      learning_rate_decay_staircase=False,
      learning_rate_minimum=None,
      learning_rate_decay_rate=1.0,
      learning_rate_warmup_steps=100,
      learning_rate_cosine_cycle_steps=250000,
      learning_rate=0.1,
      sampling_method="argmax",  # "argmax" or "random"
      sampling_temp=1.0,  # temperature for sampling
      sampling_keep_top_k=-1,  # If >0, ignore all but the top k logits
      # expand the logits a piece at a time - saves memory.
      factored_logits=False,
      multiply_embedding_mode="sqrt_depth",
      # Parameters related to mixtures of experts.
      moe_hidden_sizes="2048",  # hidden layer sizes (comma-separated)
      moe_num_experts=64,  # number of experts per layer
      moe_k=2,  # how many experts to use for each batch element
      moe_loss_coef=1e-2,
      # Sequences of operations to perform on layer input and layer output.
      # Used by common_layers.layer_preprocess, common_layers.layer_postprocess
      # Each character represents an operation:
      # none: no preprocessing
      #    d: apply dropout
      #    n: apply normalization (see norm_type and norm_epsilon)
      #    a: add layer input (residual connection - only during postprocess)
      # The special string "none" is used instead of the empty string
      # to indicate no pre/postprocessing, since the empty string causes
      # trouble for hyperparameter tuning.
      # TODO(noam): The current settings ("none", "dan") are the published
      # version of the transformer.  ("n", "da") seems better for
      # harder-to-learn models, so it should probably be the default.
      layer_preprocess_sequence="none",
      layer_postprocess_sequence="dan",
      # dropout rate to use during layer_preprocess and layer_postprocess
      layer_prepostprocess_dropout=0.1,
      # broadcast dimensions for layer_prepostprocess_dropout
      # a comma-separated list of integers.
      # see common_layers.dropout_with_broadcast_dims()
      # Change this to "1" to save memory.
      layer_prepostprocess_dropout_broadcast_dims="",
      # dropout some symbols (set them to 0) before embedding.
      symbol_dropout=0.0,
      # What type of normalization to use
      norm_type="layer",  # "batch", "layer", "noam", "none".
      # epsilon parameter to normalization function
      norm_epsilon=1e-6,
      # pad vocabularies so that this value divides the vocabulary size.
      vocab_divisor=1,
      # During training, we drop sequences whose inputs and targets are shorter
      # than min_length
      min_length=0,
      # During training, we drop sequences whose inputs or targets are longer
      # than max_length.
      # If max_length==0, we use hparams.batch_size instead.
      max_length=0,
      # Pack examples on the fly.
      pack_dataset=False,
      # Use custom ops not included in standard tensorflow.
      use_custom_ops=True,
      # Split targets on the first axis into chunks of this length.
      split_targets_chunk_length=0,
      split_targets_max_chunks=100,
      split_targets_strided_training=False,
      # Maximum length in the smallest length bucket.  Setting this
      # flag too high will result in wasteful padding of short
      # sequences.  Due to some (hopefully) temporary hacks in the
      # data reading and batching code, setting this flag too low
      # results in a very long batch-shuffling queue.
      # TODO(noam): change this once the Datasets API changes.
      min_length_bucket=8,
      # This flag controls the number of length buckets in the data
      # reader.  The buckets have maximum lengths from
      # min_length_bucket to (max_length or batch_size), increasing
      # (approximately) by factors of length_bucket_step.
      length_bucket_step=1.1,
      # If set to True, drop sequences longer than max_length during eval.
      # This affects the validity of the evaluation metrics.
      eval_drop_long_sequences=False,
      # If True, run the model autoregressively instead of teacher-forcing
      # during eval
      eval_run_autoregressive=False,
      # (For features with symbol modality) If True, share all of the
      # input embeddings, target embeddings, and softmax weights.
      shared_embedding_and_softmax_weights=False,
      # (For features with symbol modality) If True, share the input embeddings
      # and target embeddings.
      shared_embedding=False,
      # (For features with symbol modality) Number to shard embeddings by.
      symbol_modality_num_shards=1,
      # Feature transformations are optional dictionaries comprising key-value
      # pairs of a feature name (str) and its transformation (function). If not
      # specified, T2TModel applies a default transformation according to the
      # feature's modality. Bottom is applicable to all features; loss, top, and
      # weights_fn are only applicable to target features.
      # TODO(trandustin): `name` is an optional hparam for legacy reasons,
      # defining variable scope names. Remove this hparam in the future.
      bottom={},
      loss={},
      name={},
      top={},
      weights_fn={},
      # The maximum length of "input" sequence.
      # Sequences longer than this value will be truncated. 0 or negative values
      # mean there is no maximum or truncation.
      # You can change this behavior by overriding preprocess_example() method
      # in your problem class.
      max_input_seq_length=0,
      # The maximum length of "target" sequence.
      # Sequences longer than this value will be truncated. 0 or negative values
      # mean there is no maximum or truncation.
      # You can change this behavior by overriding preprocess_example() method
      # in your problem class.
      max_target_seq_length=0,
      # if nonzero, we split the target sequences on example read.
      # This is for use with language modeling problems with fixed length
      # examples.  e.g.  The examples may be written with length 65536, but we
      # want to split each example into 64 examples of length 1024.
      split_to_length=0,
      # Video settings: how many frames to batch on input and targets.
      video_num_input_frames=1,
      video_num_target_frames=1,
      # This flag allows us to optionally treat a seq-to-seq problem
      # as a language model.  Legal values are:
      #
      # "none" - Do not prepend the inputs to the targets.
      # "prepend_inputs_masked_attention"
      #     replace "targets" in preprocessing with
      #     tf.concat([inputs, [0], targets], axis=1)
      #     i.e. we prepend the inputs to the targets with a single
      #     padding token in between.  Use masked self-attention on the
      #     entire resulting sequence.  During training, we compute losses on
      #     the combined sequence.  During eval, we compute the metrics
      #     on only the targets portion.
      # "prepend_inputs_full_attention"
      #     similar to the previous option except that each
      #     position in the inputs portion can see the
      #     entire inputs portion.  This removes the challenge of
      #     autoregressively predicting the inputs portion.
      prepend_mode="none",
      # Scheduled sampling is interesting for auto-regressive models.
      # It runs an additional step using the generated output as autoregressive
      # targets, which can improve the model's inference results later. The
      # parameter scheduled_sampling_prob determines the probability with which
      # such an additional step is run. It's turned off (0.0) by default.
      # This probability will exponentially warm up for the number of
      # steps determined by scheduled_sampling_warmup_steps.
      # The tensor used for the n-th pass will consist of outputs from
      # the (n-1)-th pass mixed with gold truth, with the proportion of gold
      # determined by scheduled_sampling_gold_mixin_prob. Control the number
      # of passes with scheduled_sampling_num_passes.
      scheduled_sampling_prob=0.0,
      scheduled_sampling_method="parallel",  # parallel or sequential.
      scheduled_sampling_warmup_steps=50000,
      scheduled_sampling_gold_mixin_prob=0.5,
      scheduled_sampling_num_passes=1,
      scheduled_sampling_warmup_schedule="exp",  # exp, linear, or sigmoid.

      # This setting controls whether to copy variables around in a daisy chain
      # (if true) or leave their placement to TensorFlow. It only affects multi
      # device training and mostly should be turned on for performance. One
      # exception are recurrent models: with dynamic loops it must be off.
      daisy_chain_variables=True,
      # If True in PREDICT mode, then last-position-only optimizations are not
      # used.
      force_full_predict=False,
      # Set this for pure model parallelism.  There is only one data shard.
      no_data_parallelism=False,
      # dtype used for activations. - "float32" or "bfloat16"
      # activation_dtype="bfloat16" currently only works on TPU.
      #    It lowers activation-memory usage
      #    and does not appear to affect quality.
      #    You can train on TPU with activation_dtype="bfloat16" and evaluate
      #    on CPU/GPU with activation_dtype="float32"
      activation_dtype="float32",
      # dtype used for parameters: "float32" or "bfloat16"
      # bfloat16 currently only works with optimizer="adafactor".
      #   The savings in memory allow for training larger models.
      #   Weights are encoded as (w*128)^8, using pseudostochastic
      #   roundoff.  Initial experiments show that model quality is similar
      #   to baseline for about 3M training steps, but worse thereafter.
      weight_dtype="float32",
      # Directory containing a checkpoint for a pretrained model. This will only
      # be used if a new run is being started. Parameters not found in the
      # pretrained model will be randomly initialized. Superfluous parameters in
      # the pretrained model will be ignored.
      pretrained_model_dir="",
      # Threshold used for two cases: the primary task probability for the
      # constant mixing schedule, and the exponential schedule limit for when
      # mixing should stop (eg: 0.5 means stop at 50-50 mixing, 0.8 means stop
      # at 20-80 mixing for the primary-others mixing case.)
      multiproblem_schedule_threshold=0.5,
      # For more than 2 tasks, we may want to specify per-task thresholds here.
      # In that case, this needs to be a string with as many floating point
      # numbers as the number of tasks in the multi-problem. These numbers
      # are later normalized to add up to 1 and taken as probabilities for
      # each task. This enforces a constant mixing schedule and if this is
      # empty then the threshold from above is used for the first task and
      # the other tasks get the remaining probability split uniformly.
      multiproblem_per_task_threshold="",
      # The number of examples at which the proportion of the mixed in datasets
      # is multiproblem_schedule_threshold
      multiproblem_schedule_max_examples=1e7,
      # When training multiproblems, we can mix the data according to different
      # schedules. Example: a constant schedule mixing 20-80 between the primary
      # and other tasks.
      # A list of supported schedules can be found in
      # `data_generators.multi_problem.py`.
      multiproblem_mixing_schedule="constant",
      # A boolean that decides whether input sequence losses and target label
      # losses in classification problems should be reweighted.
      multiproblem_reweight_label_loss=False,
      # How much weight the targets in classification problems receive. Inputs
      # receive 1 minus this weight.
      multiproblem_label_weight=0.5,
      # Hyperparameters for relative attention.
      # The maximum relative positional distance to learn an embedding for.
      max_relative_position=0,
      # If heads share the same relative embedding.
      heads_share_relative_embedding=False,
      # If relative embedding terms are added to values too.
      add_relative_to_values=False,
      # If enable the host_call which is executed every training step.
      # There could be a performance drop if host_call function is slow and
      # cannot keep up with the TPU-side computation.
      tpu_enable_host_call=False,
      # Pad batch dim of inputs to nearest multiple of batch multiple.
      pad_batch=False,
      # When true, do not evaluate on the language model data when running the
      # multiproblem since it can take a while. If False, set eval_steps to
      # something large like 6000 or 10000.
      multiproblem_target_eval_only=False,
      # Max out the vocab size to a power of 2 for efficiency and to reserve
      # extra space in the vocabulary for new task ids and label classes.
      multiproblem_vocab_size=-1,
      # When using multiproblem with generation tasks, need to truncate the
      # inputs and targets manually before concatenating them.
      multiproblem_max_input_length=-1,
      multiproblem_max_target_length=-1,
      # If positive, makes training targets fixed-length in MultiProblem.
      multiproblem_fixed_train_length=-1,
      # Load weights from a second model. For instance, when using
      # pre-trained weights, you might want to initialize the encoder
      # and decoder by loading different models.
      warm_start_from_second="",
      # Area attention hyper parameters
      area_value_mode="none",
      area_key_mode="none",
      # Using area attention for the number of layers from the bottom
      num_area_layers=0,
      max_area_width=1,
      max_area_height=1,
      memory_height=1,
      # Whether to use GPU automatic mixed precision (via graph rewrite)
      gpu_automatic_mixed_precision=False,
  )
class RangedHParams(object):
  """Defines parameter ranges for tuning.

  Each hyperparameter name is registered in exactly one of four containers
  (categorical, discrete, float, int). `to_parameter_specs` turns them into
  dicts suitable for Cloud ML Engine hyperparameter tuning.
  """
  # From ParameterConfig proto
  LINEAR_SCALE = 1
  LOG_SCALE = 2
  REVERSE_LOG_SCALE = 3

  SCALES_STR = {
      LINEAR_SCALE: "UNIT_LINEAR_SCALE",
      LOG_SCALE: "UNIT_LOG_SCALE",
      REVERSE_LOG_SCALE: "UNIT_REVERSE_LOG_SCALE",
  }

  def __init__(self):
    # Each maps hparam name -> tuple of the arguments given to its set_* call.
    self._categorical_params = {}
    self._discrete_params = {}
    self._float_params = {}
    self._int_params = {}

  def _check_reset_and_type_change(self, name, orig_ctr):
    """Check if name is in orig_ctr or in one of the other type containers.

    Args:
      name: hyperparameter name about to be registered.
      orig_ctr: the container `name` is about to be registered in.

    Raises:
      ValueError: if `name` is already registered in a different container.
    """
    # Resetting a hyperparameter
    if name in orig_ctr:
      tf.logging.warning("Overwriting hparam %s", name)

    ctr_names = [
        (self._categorical_params, "categorical"),
        (self._discrete_params, "discrete"),
        (self._float_params, "float"),
        (self._int_params, "int"),
    ]
    # Look the container up by identity. `tuple.index` compares dicts with
    # `==`, which matches the wrong container when two hold equal contents
    # (e.g. both empty).
    orig_name = next(
        ctr_name for ctr, ctr_name in ctr_names if ctr is orig_ctr)

    for ctr, ctr_name in ctr_names:
      if ctr is orig_ctr:
        continue

      # Using a different type for the same hyperparameter name
      if name in ctr:
        # New type first, then the type it was originally registered as.
        raise ValueError("Setting hyperparameter %s as type %s, but a "
                         "hyperparemeter of the same name was originally "
                         "registered as type %s" % (name, orig_name, ctr_name))

  def set_categorical(self, name, categories, length=None):
    """Register `name` as a categorical hparam over `categories`."""
    self._check_reset_and_type_change(name, self._categorical_params)
    self._categorical_params[name] = (name, categories, length)

  def set_discrete(self, name, feasible_points, scale=None, length=None):
    """Register `name` as a discrete hparam over `feasible_points`."""
    self._check_reset_and_type_change(name, self._discrete_params)
    self._discrete_params[name] = (name, feasible_points, scale, length)

  def set_float(self, name, min_val, max_val, scale=None, length=None):
    """Register `name` as a float hparam in [min_val, max_val]."""
    self._check_reset_and_type_change(name, self._float_params)
    self._float_params[name] = (name, min_val, max_val, scale, length)

  def set_int(self, name, min_val, max_val, scale=None, length=None):
    """Register `name` as an integer hparam in [min_val, max_val]."""
    self._check_reset_and_type_change(name, self._int_params)
    self._int_params[name] = (name, min_val, max_val, scale, length)

  def fix_select_params(self, hp):
    """Pin every ranged hparam that also appears in `hp` to its value there.

    A pinned hparam is re-registered as a single-point discrete range. Names
    in `hp` with no registered range are ignored.

    Args:
      hp: object whose `values()` returns a dict of name -> value (e.g. an
        HParams instance).
    """
    ctrs = [
        self._categorical_params, self._discrete_params, self._float_params,
        self._int_params
    ]
    # `.items()` works on Python 2 and 3; `iteritems` fails on Python 3.
    for key, val in hp.values().items():
      for ctr in ctrs:
        if key in ctr:
          del ctr[key]
          self.set_discrete(key, [val])
          # A name lives in at most one container and was just
          # re-registered, so there is nothing left to scan.
          break

  def to_parameter_specs(self, name_prefix=""):
    """To list of dicts suitable for Cloud ML Engine hyperparameter tuning."""
    specs = []
    for name, categories, _ in self._categorical_params.values():
      spec = {
          "parameterName": name_prefix + name,
          "type": "CATEGORICAL",
          "categoricalValues": categories,
      }
      specs.append(spec)

    for name, feasible_points, scale, _ in self._discrete_params.values():
      spec = {
          "parameterName": name_prefix + name,
          "type": "DISCRETE",
          "discreteValues": feasible_points,
      }
      if scale:
        spec["scaleType"] = self.SCALES_STR[scale]
      specs.append(spec)

    for name, min_val, max_val, scale, _ in self._float_params.values():
      spec = {
          "parameterName": name_prefix + name,
          "type": "DOUBLE",
          "minValue": min_val,
          "maxValue": max_val,
      }
      if scale:
        spec["scaleType"] = self.SCALES_STR[scale]
      specs.append(spec)

    for name, min_val, max_val, scale, _ in self._int_params.values():
      spec = {
          "parameterName": name_prefix + name,
          "type": "INTEGER",
          "minValue": min_val,
          "maxValue": max_val,
      }
      if scale:
        spec["scaleType"] = self.SCALES_STR[scale]
      specs.append(spec)

    return specs
@registry.register_ranged_hparams("basic1")
def basic_range1(ranged_hparams):
  """Register the basic tuning ranges on `ranged_hparams`, in a fixed order."""
  rhp = ranged_hparams
  log_scale = rhp.LOG_SCALE

  # Discrete search spaces.
  rhp.set_discrete("batch_size", [1024, 2048, 4096])
  rhp.set_discrete("num_hidden_layers", [1, 2, 3, 4, 5, 6])
  rhp.set_discrete("hidden_size", [32, 64, 128, 256, 512], scale=log_scale)
  for kernel_dim in ("kernel_height", "kernel_width"):
    rhp.set_discrete(kernel_dim, [1, 3, 5, 7])
  rhp.set_discrete("compress_steps", [0, 1, 2])

  # Regularization and optimization ranges.
  rhp.set_float("dropout", 0.0, 0.5)
  rhp.set_float("weight_decay", 1e-4, 10.0, scale=log_scale)
  rhp.set_float("label_smoothing", 0.0, 0.2)
  rhp.set_float("clip_grad_norm", 0.01, 50.0, scale=log_scale)
  rhp.set_float("learning_rate", 0.005, 2.0, scale=log_scale)
  rhp.set_categorical(
      "initializer", ["uniform", "orthogonal", "uniform_unit_scaling"])
  rhp.set_float("initializer_gain", 0.5, 3.5)
  rhp.set_categorical(
      "learning_rate_decay_scheme", ["none", "sqrt", "noam", "exp"])
  rhp.set_float("optimizer_adam_epsilon", 1e-7, 1e-2, scale=log_scale)
  rhp.set_float("optimizer_adam_beta1", 0.8, 0.9)
  rhp.set_float("optimizer_adam_beta2", 0.995, 0.999)
  optimizers = ["adam", "adagrad", "momentum", "rms_prop", "sgd", "yellow_fin"]
  rhp.set_categorical("optimizer", optimizers)
@registry.register_ranged_hparams
def basic_moe_range(rhp):
  """Tune only moe_loss_coef over a narrow range.

  If a model ignores this parameter, the trials measure run-to-run variance.
  """
  low, high = 0.01, 0.02
  rhp.set_float("moe_loss_coef", low, high)
================================================
FILE: tensor2tensor/layers/common_image_attention.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for attention mechanism for images."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import expert_utils
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class AttentionType(object):
  """Attention-type names accepted by cia."""
  LOCAL_1D = "local_1d"
  LOCAL_2D = "local_2d"
  GLOBAL = "global"
  GLOCAL = "global_local"
  DILATED = "dilated"
  MOE_LOCAL_1D = "moe_local1d"
  LOCAL_BLOCK = "local_block"
  NON_CAUSAL_1D = "local_1d_noncausal"
  RELATIVE_LOCAL_1D = "rel_local_1d"

  @staticmethod
  def get_choices():
    """Returns a new list with every attention type name, in a fixed order."""
    attr_names = (
        "GLOBAL",
        "GLOCAL",
        "MOE_LOCAL_1D",
        "LOCAL_1D",
        "LOCAL_2D",
        "LOCAL_BLOCK",
        "DILATED",
        "NON_CAUSAL_1D",
        "RELATIVE_LOCAL_1D",
    )
    return [getattr(AttentionType, attr) for attr in attr_names]
class DistributionType(object):
  """Output-distribution names accepted by cia."""
  CAT = "cat"
  DMOL = "dmol"

  @staticmethod
  def get_choices():
    """Returns a new list with every distribution type name."""
    return list((DistributionType.CAT, DistributionType.DMOL))
def maybe_reshape_4d_to_3d(x):
  """Merge the two middle dims of a 4D tensor; leave other ranks alone.

  Returns:
    A tuple (x, original shape list, whether the input was 4D).
  """
  shape = common_layers.shape_list(x)
  was_4d = len(shape) == 4
  if was_4d:
    batch, rows, cols, channels = shape
    x = tf.reshape(x, [batch, rows * cols, channels])
  return x, shape, was_4d
def local_attention_2d(x, hparams, attention_type="local_attention_2d"):
  """Local 2-D self-attention layer.

  Args:
    x: Input tensor (presumably [batch, rows, cols, depth] — confirm).
    hparams: HParams; reads attention_key_channels, attention_value_channels,
      hidden_size, num_heads, query_shape and memory_flange.
    attention_type: Forwarded to common_attention.multihead_attention_2d.

  Returns:
    The tensor returned by common_attention.multihead_attention_2d.
  """
  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size
  with tf.variable_scope("local_2d_self_att"):
    return common_attention.multihead_attention_2d(
        x,
        None,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        attention_type=attention_type,
        query_shape=hparams.query_shape,
        memory_flange=hparams.memory_flange,
        name="self_attention")
def local_within_block_attention(x,
                                 self_attention_bias,
                                 hparams,
                                 attention_type="local_within_block_mask_right",
                                 q_padding="VALID",
                                 kv_padding="VALID"):
  """Self-attention restricted to positions within the same block.

  Rank-4 inputs are flattened via maybe_reshape_4d_to_3d, layer_preprocess is
  applied, and the attention output is reshaped back to the input shape.

  NOTE(review): transformer_decoder_layers already applies layer_preprocess
  before calling this function, so on that path the input is preprocessed
  twice. Left unchanged because removing it would change the variables the
  model creates.

  Args:
    x: Input tensor of rank 3 or 4.
    self_attention_bias: Bias forwarded to common_attention.multihead_attention.
    hparams: HParams; reads attention_key_channels, attention_value_channels,
      hidden_size, num_heads, attention_dropout, block_width, block_length,
      q_filter_width and kv_filter_width (plus what layer_preprocess reads).
    attention_type: Forwarded to multihead_attention.
    q_padding: Query padding mode forwarded to multihead_attention.
    kv_padding: Key/value padding mode forwarded to multihead_attention.

  Returns:
    The attention output, reshaped to x's shape when x was rank 4.
  """
  flat_x, orig_shape, was_4d = maybe_reshape_4d_to_3d(x)
  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size
  with tf.variable_scope("local_within_block"):
    normed_x = common_layers.layer_preprocess(flat_x, hparams)
    y = common_attention.multihead_attention(
        normed_x,
        None,
        self_attention_bias,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        name="local_within_block")
    return tf.reshape(y, orig_shape) if was_4d else y
def local_attention_1d(x,
                       hparams,
                       attention_type="local_unmasked",
                       q_padding="VALID",
                       kv_padding="VALID"):
  """Local 1-D self-attention.

  Rank-4 inputs are flattened via maybe_reshape_4d_to_3d before attention and
  reshaped back to their original shape afterwards.

  Args:
    x: Input tensor of rank 3 or 4.
    hparams: HParams; reads attention_key_channels, attention_value_channels,
      hidden_size, num_heads, attention_dropout, shared_rel, block_width,
      block_length, q_filter_width and kv_filter_width.
    attention_type: Forwarded to common_attention.multihead_attention.
    q_padding: Query padding mode forwarded to multihead_attention.
    kv_padding: Key/value padding mode forwarded to multihead_attention.

  Returns:
    The attention output, reshaped to x's shape when x was rank 4.
  """
  flat_x, orig_shape, was_4d = maybe_reshape_4d_to_3d(x)
  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size
  with tf.variable_scope("local_1d_self_att"):
    y = common_attention.multihead_attention(
        flat_x,
        None,
        None,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        shared_rel=hparams.shared_rel,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        make_image_summary=False,
        name="self_attention")
    return tf.reshape(y, orig_shape) if was_4d else y
def get_dilated_1d_attention_mask(
    num_heads, block_size,
    num_blocks, memory_size, gap_size,
    name="dilated_mask"):
  """Dilated attention with a masking strategy.

  Builds a constant mask over a window of 2 * block_size key positions for
  each of block_size query rows. Entries are 1 (masked) by default and are
  set to 0 where the query may attend: the query's own column
  (block_size + i for row i), plus runs of memory_size consecutive columns
  to its left separated by gaps of gap_size columns, repeated num_blocks
  times and stopping once an offset would run past the left edge of the
  window.

  Args:
    num_heads: Size of the leading (head) axis; every head gets the same mask.
    block_size: Number of query rows; the key window is 2 * block_size wide.
    num_blocks: Number of memory runs to unmask per row.
    memory_size: Width of each unmasked memory run.
    gap_size: Number of masked columns between consecutive memory runs.
    name: Name of the returned tf.constant.

  Returns:
    An int32 tf.constant of shape [num_heads, 1, block_size, 2 * block_size].
  """
  mask = np.ones((num_heads, block_size, 2*block_size), bool)
  # now going over every row to do the right assignment of
  # memory blocks
  for i in range(block_size):
    # Number of columns strictly left of row i's own column (block_size + i).
    visible = 2*block_size - (block_size-i)
    # You always attend to yourself, set the mask for that
    mask[:, i, -(block_size - i)] = 0
    # Maybe num_blocks can be automatically calculated?
    for j in range(num_blocks):
      for k in range(memory_size):
        # Offset to the left of the self column: run j starts after j full
        # (memory + gap) strides; k walks within the run.
        index = ((gap_size + memory_size)*j) + k
        if index >= visible:
          # Would land before column 0. Only exits the k loop, but every
          # later j gives an even larger index, so nothing more is unmasked.
          break
        # Column (block_size + i) - (index + 1): index 0 is the column
        # immediately left of the query's own position.
        mask[:, i, -(index + block_size - i + 1)] = 0  # Verify
  # adding a num blocks dimension
  mask = np.expand_dims(mask, axis=1)
  return tf.constant(mask, dtype=tf.int32, name=name)
def dilated_attention_1d(x,
                         hparams,
                         attention_type="masked_dilated_1d",
                         q_padding="VALID",
                         kv_padding="VALID",
                         gap_size=2):
  """Dilated 1-D self-attention.

  Rank-4 inputs are flattened via maybe_reshape_4d_to_3d, attended over, and
  reshaped back, with the static last dimension pinned to hidden_size.

  Args:
    x: Input tensor of rank 3 or 4.
    hparams: HParams; reads attention_key_channels, attention_value_channels,
      hidden_size, num_heads, attention_dropout, block_width, block_length,
      q_filter_width, kv_filter_width and num_memory_blocks.
    attention_type: Forwarded to common_attention.multihead_attention.
    q_padding: Query padding mode forwarded to multihead_attention.
    kv_padding: Key/value padding mode forwarded to multihead_attention.
    gap_size: Forwarded to multihead_attention as `gap_size`.

  Returns:
    The attention output, reshaped to x's shape when x was rank 4.
  """
  flat_x, orig_shape, was_4d = maybe_reshape_4d_to_3d(x)
  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size
  with tf.variable_scope("masked_dilated_1d"):
    y = common_attention.multihead_attention(
        flat_x,
        None,
        None,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        gap_size=gap_size,
        num_memory_blocks=hparams.num_memory_blocks,
        name="self_attention")
    if not was_4d:
      return y
    y = tf.reshape(y, orig_shape)
    y.set_shape([None, None, None, hparams.hidden_size])
    return y
def local_global_attention(x,
                           self_attention_bias,
                           hparams,
                           q_padding="LEFT",
                           kv_padding="LEFT"):
  """Half-global, half-local 1-D self-attention.

  The last axis of `x` is split in two. The first half goes through full
  attention and the second half through "local_masked" attention. Each half
  uses int(hidden_size / 2) output channels and int(num_heads / 2) heads,
  and the two results are concatenated on the last axis.

  Args:
    x: Input tensor whose last axis is split into two equal halves.
    self_attention_bias: Only its presence matters. When not None it is
      replaced by get_self_attention_bias(x) for the global half; when None
      the global half runs without a bias.
    hparams: HParams; reads hidden_size, num_heads, attention_key_channels,
      attention_value_channels, attention_dropout, q_filter_width,
      kv_filter_width, block_length and block_width.
    q_padding: Query padding mode for both halves.
    kv_padding: Key/value padding mode for both halves.

  Returns:
    The concatenation [global_output, local_output] along the last axis.
  """
  with tf.variable_scope("self_local_global_att"):
    x_global, x_local = tf.split(x, 2, axis=-1)
    half_depth = int(hparams.hidden_size / 2)
    half_heads = int(hparams.num_heads / 2)
    if self_attention_bias is not None:
      self_attention_bias = get_self_attention_bias(x)
    key_depth = hparams.attention_key_channels or half_depth
    value_depth = hparams.attention_value_channels or half_depth
    y_global = common_attention.multihead_attention(
        x_global,
        None,
        self_attention_bias,
        key_depth,
        value_depth,
        half_depth,
        half_heads,
        hparams.attention_dropout,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="global_self_att")
    y_local = common_attention.multihead_attention(
        x_local,
        None,
        None,
        key_depth,
        value_depth,
        half_depth,
        half_heads,
        hparams.attention_dropout,
        attention_type="local_masked",
        block_length=hparams.block_length,
        block_width=hparams.block_width,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="local_self_att")
    return tf.concat([y_global, y_local], axis=-1)
def full_self_attention(x,
                        self_attention_bias,
                        hparams,
                        q_padding="LEFT",
                        kv_padding="LEFT"):
  """Full (dot-product) self-attention layer.

  Args:
    x: Input tensor of rank 3 or 4; rank 4 is flattened via
      maybe_reshape_4d_to_3d and restored afterwards.
    self_attention_bias: Only its presence matters. When not None a
      lower-triangle bias is rebuilt from the (flattened) input with
      get_self_attention_bias; when None no bias is used.
    hparams: HParams; reads attention_key_channels, attention_value_channels,
      hidden_size, num_heads, attention_dropout, q_filter_width and
      kv_filter_width.
    q_padding: Query padding mode forwarded to multihead_attention.
    kv_padding: Key/value padding mode forwarded to multihead_attention.

  Returns:
    The attention output, reshaped to x's shape (with static last dimension
    hidden_size) when x was rank 4.
  """
  flat_x, orig_shape, was_4d = maybe_reshape_4d_to_3d(x)
  bias = (get_self_attention_bias(flat_x)
          if self_attention_bias is not None else None)
  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size
  with tf.variable_scope("self_att"):
    y = common_attention.multihead_attention(
        flat_x,
        None,
        bias,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="self_att")
    if was_4d:
      y = tf.reshape(y, orig_shape)
      y.set_shape([None, None, None, hparams.hidden_size])
    return y
def encdec_attention_1d(x,
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hparams):
  """Encoder-decoder attention: queries from `x`, memory from encoder_output.

  Both inputs are flattened to rank 3 via maybe_reshape_4d_to_3d when they are
  rank 4. The result is reshaped back to x's shape when x was rank 4.

  Args:
    x: Decoder-side tensor of rank 3 or 4.
    encoder_output: Encoder-side tensor of rank 3 or 4.
    encoder_decoder_attention_bias: Bias forwarded to multihead_attention.
    hparams: HParams; reads attention_key_channels, attention_value_channels,
      hidden_size, num_heads and attention_dropout.

  Returns:
    The attention output.
  """
  queries, orig_shape, was_4d = maybe_reshape_4d_to_3d(x)
  memory = maybe_reshape_4d_to_3d(encoder_output)[0]
  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size
  with tf.variable_scope("encdec_attention"):
    y = common_attention.multihead_attention(
        queries,
        memory,
        encoder_decoder_attention_bias,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    if was_4d:
      y = tf.reshape(y, orig_shape)
      y.set_shape([None, None, None, hparams.hidden_size])
    return y
def transformer_decoder_layers(inputs,
                               encoder_output,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               encoder_decoder_attention_bias=None,
                               attention_type=AttentionType.LOCAL_2D,
                               losses=None,
                               name="transformer"):
  """Multi-layer transformer decoder for images.

  Each layer applies self-attention (variant chosen by `attention_type`),
  optional encoder-decoder attention, and ffn_layer. Every sub-layer is
  wrapped as layer_postprocess(x, sublayer(layer_preprocess(x))).

  Args:
    inputs: Input tensor.
    encoder_output: Encoder tensor, or None to skip encoder-decoder attention.
    num_layers: Number of layers to stack.
    hparams: HParams; also needs `gap_sizes` (one per layer) for DILATED.
    self_attention_bias: Bias used by LOCAL_BLOCK, GLOCAL and GLOBAL.
    encoder_decoder_attention_bias: Bias for encoder-decoder attention.
    attention_type: An AttentionType value selecting the self-attention.
    losses: Optional list that ffn_layer may append auxiliary losses to.
    name: Prefix for per-layer variable scopes.

  Returns:
    layer_preprocess applied to the output of the last layer.

  Raises:
    ValueError: if `attention_type` is DILATED and
      len(hparams.gap_sizes) != num_layers, or if `attention_type` is not
      supported here (e.g. AttentionType.MOE_LOCAL_1D). Previously an
      unsupported type failed with an UnboundLocalError on `y`.
  """
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if (attention_type == AttentionType.DILATED and
      len(hparams.gap_sizes) != num_layers):
    raise ValueError(
        "hparams.gap_sizes must have one entry per layer: got %d entries for "
        "%d layers" % (len(hparams.gap_sizes), num_layers))
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_mask_right",
                               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.RELATIVE_LOCAL_1D:
        y = local_attention_1d(
            common_layers.layer_preprocess(x, hparams),
            hparams,
            attention_type="local_relative_mask_right",
            q_padding="LEFT",
            kv_padding="LEFT")
      elif attention_type == AttentionType.NON_CAUSAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_unmasked",
                               q_padding="VALID", kv_padding="VALID")
      elif attention_type == AttentionType.LOCAL_BLOCK:
        y = local_within_block_attention(
            common_layers.layer_preprocess(x, hparams),
            self_attention_bias, hparams,
            attention_type="local_within_block_mask_right",
            q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = local_global_attention(common_layers.layer_preprocess(x, hparams),
                                   self_attention_bias, hparams,
                                   q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.DILATED:
        y = dilated_attention_1d(common_layers.layer_preprocess(x, hparams),
                                 hparams, q_padding="LEFT",
                                 kv_padding="LEFT",
                                 gap_size=hparams.gap_sizes[layer])
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                self_attention_bias, hparams,
                                q_padding="LEFT", kv_padding="LEFT")
      else:
        raise ValueError(
            "Unsupported attention_type for transformer_decoder_layers: %r"
            % (attention_type,))
      x = common_layers.layer_postprocess(x, y, hparams)
      # enc-dec attention + skip connections
      if encoder_output is not None:
        y = encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                encoder_output,
                                encoder_decoder_attention_bias,
                                hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layers + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams,
                    losses=losses)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def transformer_encoder_layers(inputs,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.GLOBAL,
                               self_attention_bias=None,
                               q_padding="VALID",
                               kv_padding="VALID",
                               name="transformer"):
  """Multi-layer transformer encoder.

  Each layer applies self-attention and then ffn_layer. Every sub-layer is
  wrapped as layer_postprocess(x, sublayer(layer_preprocess(x))).

  Args:
    inputs: Input tensor.
    num_layers: Number of layers to stack.
    hparams: HParams forwarded to the attention and feed-forward helpers.
    attention_type: One of AttentionType.LOCAL_2D, LOCAL_1D or GLOBAL.
    self_attention_bias: Bias used by GLOBAL attention.
    q_padding: Query padding mode for LOCAL_1D / GLOBAL.
    kv_padding: Key/value padding mode for LOCAL_1D / GLOBAL.
    name: Prefix for per-layer variable scopes.

  Returns:
    layer_preprocess applied to the output of the last layer.

  Raises:
    ValueError: if `attention_type` is not one of the supported values.
      Previously this failed with an UnboundLocalError on `y`.
  """
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    # attention layers + skip connections
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_unmasked",
                               q_padding=q_padding, kv_padding=kv_padding)
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                self_attention_bias, hparams,
                                q_padding=q_padding, kv_padding=kv_padding)
      else:
        raise ValueError(
            "Unsupported attention_type for transformer_encoder_layers: %r"
            % (attention_type,))
      x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layer + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def ffn_layer(x, hparams, losses=None):
  """Feed-forward sub-layer selected by `hparams.ffn_layer`.

  Args:
    x: Input tensor.
    hparams: HParams. `hparams.ffn_layer` is one of "none",
      "conv_hidden_relu", "normed_conv_hidden_relu", "self_attention_ffn",
      "local_moe_tpu" or "glu_ffn"; the other fields read depend on it.
    losses: Optional list. Required for "local_moe_tpu", whose auxiliary
      loss is appended to it.

  Returns:
    The feed-forward output (`x` itself when ffn_layer is "none").

  Raises:
    ValueError: if ffn_layer is "local_moe_tpu" and `losses` is None, or if
      ffn_layer is not a supported value. An unknown value previously hit an
      `assert`, which `python -O` strips, silently falling through to GLU.
  """
  with tf.variable_scope("ffn"):
    if hparams.ffn_layer == "none":
      return x
    if hparams.ffn_layer == "conv_hidden_relu":
      y = common_layers.dense_relu_dense(
          x,
          hparams.filter_size,
          hparams.hidden_size,
          dropout=hparams.relu_dropout)
    elif hparams.ffn_layer == "normed_conv_hidden_relu":
      y = common_layers.normed_conv_hidden_relu(
          x,
          hparams.norm_type,
          hparams.layer_norm_epsilon,
          hparams.filter_size,
          hparams.hidden_size,
          dropout=hparams.relu_dropout,
          norm_name="convnorm")
    elif hparams.ffn_layer == "self_attention_ffn":
      x_shape = tf.shape(x)
      x = tf.reshape(x, [x_shape[0], -1, hparams.hidden_size])
      y = common_attention.ffn_self_attention_layer(
          x, hparams.filter_size, hparams.hidden_size, hparams.num_parts,
          hparams.attention_dropout, hparams.share_kv)
      y = tf.reshape(y, x_shape)
    elif hparams.ffn_layer == "local_moe_tpu":
      # Validate before building any MoE ops: the auxiliary loss must have a
      # destination.
      if losses is None:
        raise ValueError(
            "transformer_ffn_layer with type local_moe_tpu must pass in "
            "a losses list")
      overhead = (hparams.moe_overhead_train
                  if hparams.mode == tf_estimator.ModeKeys.TRAIN
                  else hparams.moe_overhead_eval)
      x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
      y, loss = expert_utils.local_moe_tpu(
          x, hparams.filter_size // 2,
          hparams.hidden_size,
          hparams.moe_num_experts, overhead=overhead,
          loss_coef=hparams.moe_loss_coef)
      if is_4d:
        y = tf.reshape(y, x_shape)
      losses.append(loss)
    elif hparams.ffn_layer == "glu_ffn":
      y = common_layers.gated_linear_unit_layer(x)
    else:
      raise ValueError("Unknown hparams.ffn_layer: %r" % (hparams.ffn_layer,))
    return y
def get_self_attention_bias(x):
  """Builds a lower-triangle self-attention bias sized to `x`'s length.

  Args:
    x: A tensor whose axis 1 is the sequence length, e.g.
      [batch, length, depth].

  Returns:
    The bias returned by common_attention.attention_bias_lower_triangle for
    that length. NOTE(review): the previous docstring gave its shape as
    [length, length, 1]; confirm against attention_bias_lower_triangle.
  """
  length = common_layers.shape_list(x)[1]
  return common_attention.attention_bias_lower_triangle(length)
def postprocess_image(x, rows, cols, hparams):
  """Projects decoder output to per-position distribution parameters.

  Args:
    x: Tensor of shape [batch, ...] with batch * rows * cols *
      hparams.hidden_size elements in total.
    rows: Integer number of rows in a 2-D data point.
    cols: Integer number of columns in a 2-D data point.
    hparams: HParams set.

  Returns:
    Tensor of shape [batch, rows, cols, depth]. depth is
    hparams.num_mixtures * 10 when hparams.likelihood is DMOL, otherwise 256.
    In PREDICT mode with hparams.block_raster_scan, the tensor is instead
    [batch, num_block_rows, num_block_cols, block_length, block_width, depth],
    with (block_length, block_width) = hparams.query_shape.
  """
  batch = common_layers.shape_list(x)[0]
  x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size])
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  is_dmol = likelihood == DistributionType.DMOL
  depth = hparams.num_mixtures * 10 if is_dmol else 256
  # DMOL projects without a bias; the categorical head uses one.
  targets = tf.layers.dense(x,
                            depth,
                            use_bias=not is_dmol,
                            activation=None,
                            name="output_conv")
  if not (hparams.mode == tf_estimator.ModeKeys.PREDICT and
          hparams.block_raster_scan):
    return targets
  block_length, block_width = hparams.query_shape[0], hparams.query_shape[1]
  # Split the row axis into (num_block_rows, block_length).
  row_shape = common_layers.shape_list(targets)
  by_rows = tf.reshape(targets,
                       [batch, row_shape[1] // block_length, block_length,
                        row_shape[2], depth])
  # Split the column axis into (num_block_cols, block_width).
  col_shape = common_layers.shape_list(by_rows)
  blocks = tf.reshape(by_rows,
                      [batch, col_shape[1], col_shape[2],
                       col_shape[3] // block_width, block_width, depth])
  # -> [batch, num_block_rows, num_block_cols, block_length, block_width,
  #     depth].
  return tf.transpose(blocks, [0, 1, 3, 2, 4, 5])
def prepare_encoder(inputs, hparams, attention_type="local_1d"):
  """Prepares image inputs for the encoder.

  Adds position signals and, for "local_1d", flattens the two spatial axes
  into one.

  Args:
    inputs: Image tensor (presumably [batch, rows, cols, hidden_size]).
    hparams: HParams; reads hidden_size, plus whatever add_pos_signals reads.
    attention_type: "local_1d" (flatten to rank 3) or "local_2d" (keep rank
      4). Any other value only adds position signals.

  Returns:
    The prepared encoder input.
  """
  x = prepare_image(inputs, hparams, name="enc_channels")
  x = add_pos_signals(x, hparams, "enc_pos")
  dims = common_layers.shape_list(x)
  if attention_type == "local_1d":
    flat_length = dims[1] * dims[2]
    x = tf.reshape(x, [dims[0], flat_length, hparams.hidden_size])
    x.set_shape([None, None, hparams.hidden_size])
  elif attention_type == "local_2d":
    x.set_shape([None, None, None, hparams.hidden_size])
  return x
def prepare_decoder(targets, hparams):
  """Prepare decoder for images.

  Pads/reshapes `targets` into image layout (at inference), shifts it right
  (blockwise for LOCAL_2D / LOCAL_BLOCK decoder attention, otherwise in the
  flattened raster order), and adds position signals.

  Args:
    targets: Target tensor. Per the comment below, [batch, IMG_LEN, IMG_LEN,
      3] during training and [batch, curr_infer_length, 1, 1] at inference.
    hparams: HParams; reads num_channels, mode, block_raster_scan, img_len,
      query_shape, dec_attention_type and hidden_size, plus whatever
      add_pos_signals reads.

  Returns:
    A tuple (x, rows, cols): the decoder input, and dimensions 1 and 2 of the
    image shape measured before the shift and position signals.
  """
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None
  # during training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1]
  if hparams.mode == tf_estimator.ModeKeys.PREDICT:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len*channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len*channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len*channels
      # NOTE(review): redundant; curr_infer_length was already set above.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      # `-n % m` is the amount needed to pad n up to the next multiple of m.
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])
      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # add padding to make sure the size of targets is a multiple of img_height
      # times number of channels. This is needed for positional encodings and
      # for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    # Both inference paths end in image layout [batch, rows, img_len, channels].
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals
    # Flatten to a single sequence axis so the right shift follows raster
    # order across rows, then restore the 2-D layout.
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1]*x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
def prepare_image(inputs, hparams, name=None):
  """Identity pass-through kept so existing call sites keep working.

  Args:
    inputs: Any value; returned unchanged.
    hparams: Ignored.
    name: Ignored.

  Returns:
    `inputs`, unchanged.
  """
  # TODO(trandustin): This is a legacy function. Remove its usage.
  del hparams  # unused
  del name  # unused
  return inputs
def create_output(decoder_output, rows, cols, targets, hparams):
  """Creates output from decoder output and vars.

  Args:
    decoder_output: Tensor of shape [batch, ...] with batch * rows * cols *
      hparams.hidden_size elements in total.
    rows: Integer number of rows in a 2-D data point.
    cols: Integer number of columns in a 2-D data point.
    targets: Unused; kept for interface compatibility.
    hparams: HParams set.

  Returns:
    In PREDICT mode: a rank-5 tensor [batch, rows, 1, 1, depth] (the first
    `rows` entries of the flattened output). Otherwise, for CAT likelihood:
    [batch, rows, cols // num_channels, num_channels, depth]. Otherwise
    (DMOL): the postprocess_image output unchanged.
  """
  del targets  # unused arg
  decoded_image = postprocess_image(decoder_output, rows, cols, hparams)
  batch = common_layers.shape_list(decoded_image)[0]
  depth = common_layers.shape_list(decoded_image)[-1]
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  if hparams.mode == tf_estimator.ModeKeys.PREDICT:
    flattened = tf.reshape(decoded_image, [batch, -1, 1, 1, depth])
    return flattened[:, :rows, :, :, :]
  if likelihood != DistributionType.CAT:
    return decoded_image
  # Unpack the channel axis that is interleaved into the columns.
  channels = hparams.num_channels
  return tf.reshape(decoded_image,
                    [batch, rows, cols // channels, channels, depth])
def get_channel_embeddings(io_depth, targets, hidden_size, name="channel"):
  """Embeds each channel of `targets` with its own slice of one table.

  A single variable of shape [256 * io_depth, hidden_size] is created and
  scaled by sqrt(hidden_size). Channel c looks up rows [c * 256, c * 256 + 256).

  Args:
    io_depth: Number of channels along axis 3 of `targets`.
    targets: Integer tensor of rank 4, split on axis 3 (values presumably in
      [0, 256) — confirm with callers).
    hidden_size: Embedding width per channel.
    name: Suffix for the variable name "rgb_target_emb_<name>".

  Returns:
    The per-channel embeddings concatenated along the last axis.
  """
  per_channel_ids = tf.split(targets, io_depth, axis=3)
  table = tf.get_variable("rgb_target_emb_%s" % name,
                          [256 * io_depth, hidden_size])
  table = tf.identity(table)
  table *= float(hidden_size)**0.5
  # Offset each channel's ids so channels index disjoint rows of the table.
  embeddings = [
      common_layers.gather(table, tf.squeeze(ids, axis=3) + channel * 256)
      for channel, ids in enumerate(per_channel_ids)
  ]
  return tf.concat(embeddings, axis=-1)
def add_pos_signals(x, hparams, name="pos_emb"):
  """Adds position information to `x`.

  Args:
    x: Input tensor.
    hparams: HParams. `hparams.pos` is "timing" (uses
      common_attention.add_timing_signal_nd) or "emb" (uses
      common_attention.add_positional_embedding_nd, which also reads
      hparams.max_length).
    name: Variable scope name, also passed to add_positional_embedding_nd.

  Returns:
    `x` with position signals added.

  Raises:
    ValueError: if `hparams.pos` is neither "timing" nor "emb". This was an
      `assert`, which `python -O` strips, silently using embeddings.
  """
  with tf.variable_scope(name, reuse=False):
    if hparams.pos == "timing":
      return common_attention.add_timing_signal_nd(x)
    if hparams.pos == "emb":
      return common_attention.add_positional_embedding_nd(
          x, hparams.max_length, name)
    raise ValueError(
        "hparams.pos must be 'timing' or 'emb', got %r" % (hparams.pos,))
================================================
FILE: tensor2tensor/layers/common_image_attention_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for common image attention utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_image_attention
from tensor2tensor.utils import hparam
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class CommonImageAttentionTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for postprocess_image, create_output and the decoder stack."""

  # Cases are (likelihood, num_mixtures, expected depth): DMOL expects
  # num_mixtures * 10 = 50 output channels, CAT expects 256.
  @parameterized.parameters(
      (common_image_attention.DistributionType.DMOL, 5, 50),
      (common_image_attention.DistributionType.CAT, None, 256),
  )
  def testPostProcessImageTrainMode(self, likelihood, num_mixtures, depth):
    """Outside PREDICT mode the output is [batch, rows, cols, depth]."""
    batch = 1
    rows = 8
    cols = 24
    hparams = hparam.HParams(
        hidden_size=2,
        likelihood=likelihood,
        mode=tf_estimator.ModeKeys.TRAIN,
        num_mixtures=num_mixtures,
    )
    inputs = tf.random_uniform([batch, rows, cols, hparams.hidden_size],
                               minval=-1., maxval=1.)
    outputs = common_image_attention.postprocess_image(
        inputs, rows, cols, hparams)
    self.assertEqual(outputs.shape, (batch, rows, cols, depth))

  # Same cases as above, in PREDICT mode with block raster scan enabled.
  @parameterized.parameters(
      (common_image_attention.DistributionType.DMOL, 5, 50),
      (common_image_attention.DistributionType.CAT, None, 256),
  )
  def testPostProcessImageInferMode(self, likelihood, num_mixtures, depth):
    """PREDICT + block_raster_scan splits the output into query_shape blocks."""
    batch = 1
    rows = 8
    cols = 24
    block_length = 4
    block_width = 2
    hparams = hparam.HParams(
        block_raster_scan=True,
        hidden_size=2,
        likelihood=likelihood,
        mode=tf_estimator.ModeKeys.PREDICT,
        num_mixtures=num_mixtures,
        query_shape=[block_length, block_width],
    )
    inputs = tf.random_uniform([batch, rows, cols, hparams.hidden_size],
                               minval=-1., maxval=1.)
    outputs = common_image_attention.postprocess_image(
        inputs, rows, cols, hparams)
    num_blocks_rows = rows // block_length
    num_blocks_cols = cols // block_width
    self.assertEqual(outputs.shape,
                     (batch, num_blocks_rows, num_blocks_cols,
                      block_length, block_width, depth))

  @parameterized.parameters(
      (common_image_attention.DistributionType.DMOL, 5, 50),
      (common_image_attention.DistributionType.CAT, None, 256),
  )
  def testCreateOutputTrainMode(self, likelihood, num_mixtures, depth):
    """create_output unpacks channels for CAT and passes DMOL through."""
    batch = 1
    height = 8
    width = 8
    channels = 3
    rows = height
    # For CAT the channels are interleaved along the column axis, so the
    # decoder output has channels * width columns.
    if likelihood == common_image_attention.DistributionType.CAT:
      cols = channels * width
    else:
      cols = width
    hparams = hparam.HParams(
        hidden_size=2,
        likelihood=likelihood,
        num_channels=channels,
        mode=tf_estimator.ModeKeys.TRAIN,
        num_mixtures=num_mixtures,
    )
    decoder_output = tf.random_normal([batch, rows, cols, hparams.hidden_size])
    targets = tf.random_uniform([batch, height, width, channels],
                                minval=-1., maxval=1.)
    output = common_image_attention.create_output(
        decoder_output, rows, cols, targets, hparams)
    if hparams.likelihood == common_image_attention.DistributionType.CAT:
      self.assertEqual(output.shape, (batch, height, width, channels, depth))
    else:
      self.assertEqual(output.shape, (batch, height, width, depth))

  def testTransformerDecoderLayersGlobal(self):
    """GLOBAL decoder attention with a lower-triangle bias.

    Inputs are shifted right by one step, so position 0 holds the zero padding
    for every example. With the lower-triangle bias it can only see itself,
    so its output must match across the batch.
    """
    one_hot_data = tf.constant([[[0., 1.], [1., 0.]],
                                [[0., 1.], [1., 0.]],
                                [[1., 0.], [1., 0.]]])
    # Minimal hparams covering every field read on the GLOBAL path.
    hparams = common_hparams.basic_params1()
    hparams.hidden_size = 4
    hparams.num_layers = 1
    hparams.layer_prepostprocess_dropout = 0.
    hparams.add_hparam("attention_key_channels", None)
    hparams.add_hparam("attention_value_channels", None)
    hparams.add_hparam("num_heads", 1)
    hparams.add_hparam("attention_dropout", 0.)
    hparams.add_hparam("shared_rel", False)
    hparams.add_hparam("block_width", 1)
    hparams.add_hparam("block_length", 1)
    hparams.add_hparam("q_filter_width", 1)
    hparams.add_hparam("kv_filter_width", 1)
    hparams.add_hparam("filter_size", 16)
    hparams.add_hparam("ffn_layer", "conv_hidden_relu")
    hparams.add_hparam("relu_dropout", 0.)
    # Bias-free projection, so zero inputs stay exactly zero.
    conv_1d = tf.keras.layers.Conv1D(filters=hparams.hidden_size,
                                     kernel_size=1,
                                     use_bias=False)
    # Prepend a zero step on the length axis and drop the last step.
    shifted_data = tf.pad(one_hot_data, [[0, 0], [1, 0], [0, 0]])[..., :-1, :]
    net = conv_1d(shifted_data)
    output = common_image_attention.transformer_decoder_layers(
        inputs=net,
        encoder_output=None,
        num_layers=hparams.num_layers,
        hparams=hparams,
        self_attention_bias=common_image_attention.get_self_attention_bias(net),
        attention_type=common_image_attention.AttentionType.GLOBAL)
    self.evaluate(tf.global_variables_initializer())
    output_val = self.evaluate(output)
    # The outputs for the padded dimension should be equal across all data.
    self.assertAllEqual(output_val[0, 0], output_val[1, 0])
    self.assertAllEqual(output_val[1, 0], output_val[2, 0])
    # The first and second elements of the batch are identical, so they should
    # have the same outputs for the second latent dimension as well.
    self.assertAllEqual(output_val[0, 1], output_val[1, 1])
if __name__ == "__main__":
  # Run this module's tests with TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/common_layers.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers common to multiple models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import functools
import math
from absl import logging
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.utils import contrib
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import inplace_ops
# TODO(lukaszkaiser): remove this function when not needed any more.
def layers():
  """Returns the layers module to use under TF1 or TF2.

  Starts from `tf.layers` (None if that attribute is missing). If
  `tensorflow.python.tf2` can be imported and reports V2 mode, returns
  `tf.keras.layers` instead. ImportErrors in that second step are ignored.
  """

  def _tf1_layers():
    # tf.layers, or None when this TensorFlow build does not expose it.
    try:
      return tf.layers
    except AttributeError:
      logging.info("Cannot access tf.layers, trying TF2 layers.")
      return None

  def _keras_layers_if_v2():
    # tf.keras.layers when running in V2 mode, otherwise None.
    try:
      from tensorflow.python import tf2  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
      if not tf2.enabled():
        return None
      logging.info("Running in V2 mode, using Keras layers.")
      return tf.keras.layers
    except ImportError:
      return None

  tf1_module = _tf1_layers()
  keras_module = _keras_layers_if_v2()
  return tf1_module if keras_module is None else keras_module
@function.Defun(
    python_grad_func=lambda unused_op, dy: tf.convert_to_tensor(dy),
    shape_func=lambda op: [op.inputs[0].get_shape()])
def convert_gradient_to_tensor(x):
  """Identity op whose incoming gradient is densified with convert_to_tensor.

  Motivation: the gradient of `tf.concat` is costly when dy is an
  `IndexedSlices` (it is forced onto CPU for lack of a GPU kernel), which
  happens when the concat result later feeds `tf.gather`. Wrapping the concat
  as `convert_gradient_to_tensor(tf.concat(x))` makes dy a dense `Tensor`,
  so the cheaper dense concat gradient is used.

  Args:
    x: A `Tensor`.

  Returns:
    `x`, unchanged in the forward pass; its static shape is preserved by
    shape_func.
  """
  return x
def is_xla_compiled():
  """Reports whether the graph being built sits inside an XLA context.

  When True, model code should only emit ops that have XLA kernels and keep
  all shapes statically known.

  Returns:
    bool: True if the default graph's current control-flow context is
    contained in an XLA context.
  """
  graph = tf.get_default_graph()
  flow_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_context = control_flow_util.GetContainingXLAContext(flow_context)
  return xla_context is not None
def to_float(x):
  """Casts `x` to float32 (replacement for the deprecated tf.to_float)."""
  return tf.cast(x, dtype=tf.float32)
def dropout_with_broadcast_dims(x, keep_prob, broadcast_dims=None, **kwargs):
  """tf.nn.dropout variant where the caller lists broadcast axes.

  Instead of a full noise_shape, pass the axes along which one keep/drop
  decision is shared; noise_shape is 1 on those axes and the dynamic size of
  `x` elsewhere.

  Args:
    x: A floating point tensor of known rank.
    keep_prob: Scalar tensor of x's dtype: probability of keeping an element.
    broadcast_dims: Optional list of axis indices (negative values count from
      the end) that share one keep/drop flag.
    **kwargs: Other keyword arguments for tf.nn.dropout; must not include
      "noise_shape".

  Returns:
    A tensor with the same shape as `x`.
  """
  assert "noise_shape" not in kwargs
  if broadcast_dims:
    dynamic_shape = tf.shape(x)
    rank = len(x.get_shape())
    shared_axes = {d + rank if d < 0 else d for d in broadcast_dims}
    kwargs["noise_shape"] = [
        1 if axis in shared_axes else dynamic_shape[axis]
        for axis in range(rank)
    ]
  return tf.nn.dropout(x, keep_prob, **kwargs)
def comma_separated_string_to_integer_list(s):
  """Parses "1,2,3" into [1, 2, 3]; empty fields are skipped."""
  return list(map(int, filter(None, s.split(","))))
def saturating_sigmoid(x):
  """Computes clip(1.2 * sigmoid(x) - 0.1, 0, 1)."""
  with tf.name_scope("saturating_sigmoid", values=[x]):
    stretched = 1.2 * tf.sigmoid(x) - 0.1
    return tf.minimum(1.0, tf.maximum(0.0, stretched))
def hard_sigmoid(x, saturation_limit=0.9):
  """Piecewise-linear sigmoid: clip(0.5 * x + 0.5, 0, 1).

  Args:
    x: Input tensor.
    saturation_limit: Threshold on |x| above which saturation is penalized.

  Returns:
    A tuple (activation, saturation_cost), where saturation_cost is the mean
    of relu(|x| - saturation_limit).
  """
  overshoot = tf.nn.relu(tf.abs(x) - saturation_limit)
  saturation_cost = tf.reduce_mean(overshoot)
  activation = tf.minimum(1.0, tf.nn.relu(0.5 * x + 0.5))
  return activation, saturation_cost
def hard_tanh(x, saturation_limit=0.9):
  """Piecewise-linear tanh: clip(x, -1, 1).

  Args:
    x: Input tensor.
    saturation_limit: Threshold on |x| above which saturation is penalized.

  Returns:
    A tuple (activation, saturation_cost), where saturation_cost is the mean
    of relu(|x| - saturation_limit).
  """
  overshoot = tf.nn.relu(tf.abs(x) - saturation_limit)
  saturation_cost = tf.reduce_mean(overshoot)
  activation = tf.minimum(1.0, tf.maximum(x, -1.0))
  return activation, saturation_cost
def inverse_exp_decay(max_step, min_value=0.01, step=None):
  """Rises exponentially from min_value at step 0 to 1.0 at max_step.

  Args:
    max_step: Step at which the value reaches (and then stays at) 1.0.
    min_value: Value at step 0.
    step: Current step; defaults to the global step.

  Returns:
    A float tensor, or the Python float 1.0 if no step is available.
  """
  # min_value == inv_base ** max_step, so inv_base ** (max_step - step)
  # goes from min_value to 1.0.
  inv_base = tf.exp(tf.log(min_value) / float(max_step))
  if step is None:
    step = tf.train.get_global_step()
  if step is None:
    return 1.0
  steps_remaining = tf.maximum(float(max_step) - to_float(step), 0.0)
  return inv_base**steps_remaining
def inverse_lin_decay(max_step, min_value=0.01, step=None):
  """Rises linearly from min_value at step 0 to 1.0 at max_step.

  Args:
    max_step: Step at which the value reaches (and then stays at) 1.0.
    min_value: Value at step 0.
    step: Current step; defaults to the global step.

  Returns:
    A float tensor, or the Python float 1.0 if no step is available.
  """
  if step is None:
    step = tf.train.get_global_step()
  if step is None:
    return 1.0
  fraction_done = tf.minimum(to_float(step) / float(max_step), 1.0)
  return fraction_done * (1.0 - min_value) + min_value
def inverse_sigmoid_decay(max_step, min_value=0.01, step=None):
  """Rises along a sigmoid curve from min_value to 1.0, reached at max_step.

  The step range [0, max_step] is mapped linearly onto
  [logit(min_value), logit(1 - min_value)]. The sigmoid of that point is
  rescaled so the curve starts at min_value and ends at 1.0.

  Args:
    max_step: Step at which the value reaches (and then stays at) 1.0.
    min_value: Value at step 0; must satisfy 0 < min_value < 0.5.
    step: Current step; defaults to the global step.

  Returns:
    A float tensor, or the Python float 1.0 if no step is available.
  """
  if step is None:
    step = tf.train.get_global_step()
  if step is None:
    return 1.0
  step = to_float(step)

  def logit(p):
    return tf.log(p / (1 - p))

  assert min_value > 0, (
      "sigmoid's output is always >0 and <1. min_value must respect "
      "these bounds for interpolation to work.")
  assert min_value < 0.5, "Must choose min_value on the left half of sigmoid."
  lo = min_value
  hi = 1.0 - min_value
  x_lo = logit(lo)
  x_hi = logit(hi)
  progress = tf.minimum(step / float(max_step), 1.0)  # in [0, 1]
  x = x_lo + (x_hi - x_lo) * progress  # in [x_lo, x_hi]
  s = 1 / (1 + tf.exp(-x))  # in [lo, hi]
  normalized = (s - lo) / (hi - lo)  # in [0, 1]
  return normalized * (1.0 - lo) + lo  # in [lo, 1]
def shakeshake2_py(x, y, equal=False, individual=False):
  """The shake-shake sum of 2 tensors, python version.

  Computes alpha * x + (1 - alpha) * y, where alpha is:
    * 0.5 when `equal` is True;
    * one uniform [0, 1) draw per example (first axis of x), broadcast over
      all remaining axes, when `individual` is True;
    * a single uniform [0, 1) scalar draw otherwise.

  Args:
    x: a Tensor.
    y: a Tensor with the same shape as x.
    equal: use a fixed alpha of 0.5.
    individual: draw a separate alpha for each example.

  Returns:
    A Tensor with the shape of x.
  """
  if equal:
    alpha = 0.5
  elif individual:
    x_shape = tf.shape(x)
    # `tf.get_shape` is not a TF function; use the dynamic shape instead.
    alpha = tf.random_uniform(x_shape[:1])
    # Reshape [batch] -> [batch, 1, ..., 1]. Otherwise alpha would broadcast
    # against the last axis of x instead of the batch axis.
    alpha = tf.reshape(
        alpha, tf.concat([x_shape[:1], tf.ones_like(x_shape[1:])], axis=0))
  else:
    alpha = tf.random_uniform([])
  return alpha * x + (1.0 - alpha) * y
@function.Defun()
def shakeshake2_grad(x1, x2, dy):
  """Backward pass of shakeshake2 using a freshly drawn scalar alpha.

  Rebuilds the mix with shakeshake2_py, which draws a new random alpha, and
  backpropagates dy through it. The gradient mix therefore differs from the
  forward one.
  """
  resampled = shakeshake2_py(x1, x2)
  return tf.gradients(ys=[resampled], xs=[x1, x2], grad_ys=[dy])
@function.Defun()
def shakeshake2_indiv_grad(x1, x2, dy):
  """Backward pass of shakeshake2_indiv using fresh per-example alphas.

  Rebuilds the mix with shakeshake2_py(individual=True) and backpropagates dy
  through it.
  """
  resampled = shakeshake2_py(x1, x2, individual=True)
  return tf.gradients(ys=[resampled], xs=[x1, x2], grad_ys=[dy])
@function.Defun()
def shakeshake2_equal_grad(x1, x2, dy):
  """Backward pass of shakeshake2_eqgrad with a fixed alpha of 0.5.

  Rebuilds the mix with shakeshake2_py(equal=True) and backpropagates dy
  through it.
  """
  averaged = shakeshake2_py(x1, x2, equal=True)
  return tf.gradients(ys=[averaged], xs=[x1, x2], grad_ys=[dy])
@function.Defun(grad_func=shakeshake2_grad)
def shakeshake2(x1, x2):
  """Shake-shake mix of 2 tensors with a random scalar alpha.

  The gradient comes from shakeshake2_grad, which draws a different alpha.
  """
  return shakeshake2_py(x1, x2, equal=False, individual=False)
@function.Defun(grad_func=shakeshake2_indiv_grad)
def shakeshake2_indiv(x1, x2):
  """Shake-shake mix of 2 tensors with one random alpha per example.

  The gradient comes from shakeshake2_indiv_grad, which draws fresh
  per-example alphas.
  """
  return shakeshake2_py(x1, x2, equal=False, individual=True)
@function.Defun(grad_func=shakeshake2_equal_grad)
def shakeshake2_eqgrad(x1, x2):
  """Shake-shake mix: random scalar alpha forward, alpha = 0.5 backward.

  The gradient comes from shakeshake2_equal_grad.
  """
  return shakeshake2_py(x1, x2, equal=False, individual=False)
def shakeshake(xs, equal_grad=False):
  """Multi-argument shake-shake, currently approximated by sums of 2.

  xs is split into two halves (the first half gets the extra element when
  len(xs) is odd). Each half is reduced recursively, and the two results are
  mixed with shakeshake2, or with shakeshake2_eqgrad when equal_grad is True.

  Args:
    xs: a non-empty list of Tensors that can be mixed elementwise.
    equal_grad: if True, the backward pass mixes with alpha = 0.5.

  Returns:
    A single Tensor combining all of xs.

  Raises:
    ValueError: if xs is empty.
  """
  if not xs:
    # An empty list would otherwise recurse forever (div == 0).
    raise ValueError("shakeshake requires at least one tensor.")
  if len(xs) == 1:
    return xs[0]
  div = (len(xs) + 1) // 2
  arg1 = shakeshake(xs[:div], equal_grad=equal_grad)
  arg2 = shakeshake(xs[div:], equal_grad=equal_grad)
  if equal_grad:
    return shakeshake2_eqgrad(arg1, arg2)
  return shakeshake2(arg1, arg2)
def convert_rgb_to_real(x):
  """Convert pixel values to floats by dividing by 255.0.

  Pixel intensities in [0, 255] are mapped to [0, 1].
  """
  with tf.name_scope("rgb_to_real", values=[x]):
    as_float = to_float(x)
    return as_float / 255.0
def convert_rgb_to_symmetric_real(x):
  """Convert pixel values to floats centred on zero.

  Each intensity in [0, 255] is mapped to [-1, 1] via x / 127.5 - 1.
  """
  with tf.name_scope("rgb_to_real", values=[x]):
    as_float = to_float(x)
    # [0, 255] -> [0, 2] -> [-1, 1].
    return (as_float / 127.5) - 1
def convert_real_to_rgb(x):
  """Conversion of real numbers to pixel values.

  Multiplies x by 255.0, which undoes the division by 255.0 in
  convert_rgb_to_real. No clipping or rounding is applied.

  Args:
    x: a Tensor of real values.

  Returns:
    x * 255.0.
  """
  with tf.name_scope("real_to_rgb", values=[x]):
    # For a Tensor, `*=` rebinds x to a new tensor. NOTE(review): a mutable
    # array input would be scaled in place; confirm callers pass Tensors.
    x *= 255.0
  return x
def expand_squeeze_to_nd(x, n, squeeze_dim=2, expand_dim=-1):
  """Bring x to rank n by squeezing or expanding one axis repeatedly.

  If x starts with more than n dimensions, axis squeeze_dim is squeezed
  (tf.squeeze requires it to have size 1) until the rank is n. Otherwise
  size-1 axes are inserted at expand_dim until the rank is n.

  Args:
    x: a Tensor.
    n: the target rank.
    squeeze_dim: axis removed while the rank is too high.
    expand_dim: position where axes are added while the rank is too low.

  Returns:
    A Tensor of rank n.
  """
  shrinking = len(x.shape) > n
  while len(x.shape) != n:
    if shrinking:
      x = tf.squeeze(x, [squeeze_dim])
    else:
      x = tf.expand_dims(x, expand_dim)
  return x
def standardize_images(x):
  """Image standardization on batches and videos.

  All leading axes are folded into one batch axis. Each image then gets zero
  mean and is divided by max(stddev, 1/sqrt(num_pixels)); the statistics are
  taken over axes -3 and -2 (presumably height and width), per last-axis
  channel. The input shape is restored at the end.

  Args:
    x: a Tensor of rank >= 3.

  Returns:
    A float Tensor with the shape of x.
  """
  with tf.name_scope("standardize_images", values=[x]):
    original_shape = shape_list(x)
    images = to_float(tf.reshape(x, [-1] + original_shape[-3:]))
    pixel_axes = [1, 2]
    mean = tf.reduce_mean(images, axis=pixel_axes, keepdims=True)
    variance = tf.reduce_mean(
        tf.squared_difference(images, mean), axis=pixel_axes, keepdims=True)
    pixel_count = to_float(original_shape[-2] * original_shape[-3])
    # Lower-bound the divisor so near-constant images do not blow up.
    divisor = tf.maximum(tf.sqrt(variance), tf.rsqrt(pixel_count))
    return tf.reshape((images - mean) / divisor, original_shape)
def flatten4d3d(x):
  """Merge axes 1 and 2 of a 4-d tensor: [a, b, c, d] -> [a, b * c, d]."""
  dims = shape_list(x)
  merged = dims[1] * dims[2]
  return tf.reshape(x, [dims[0], merged, dims[3]])
# TODO(noam): remove this function after TPUs do gather faster.
def gather(params, indices, dtype=tf.float32):
  """Version of tf.gather that works faster on TPU.

  Outside XLA this is plain tf.gather. Under XLA the lookup becomes a matmul:
  flattened indices are one-hot encoded over params' first axis (with
  `dtype`), multiplied with params, and reshaped with reshape_like against
  indices plus one trailing axis. The matmul requires params to be 2-D on this
  path.

  Args:
    params: a Tensor, 2-D [vocab, depth] on the XLA path.
    indices: an integer Tensor.
    dtype: dtype of the one-hot matrix on the XLA path.

  Returns:
    The gathered Tensor.
  """
  if is_xla_compiled():
    vocab_size = params.get_shape().as_list()[0]
    one_hot = tf.one_hot(tf.reshape(indices, [-1]), vocab_size, dtype=dtype)
    flat_out = tf.matmul(one_hot, params)
    return reshape_like(flat_out, tf.expand_dims(indices, -1))
  return tf.gather(params, indices)
# TODO(noam): remove this function after TPUs do cumsum faster.
def cumsum(x, axis=0, exclusive=False):
  """TPU hack for tf.cumsum.

  This is equivalent to tf.cumsum and is faster on TPU as of 04/2018 unless
  the axis dimension is very large. Under XLA the cumulative sum is a
  tensordot with a triangular 0/1 mask. The summed axis comes out last and is
  transposed back into place.

  Args:
    x: a Tensor
    axis: an integer; negative values count from the end, as in tf.cumsum.
    exclusive: a boolean

  Returns:
    Tensor of the same shape as x.
  """
  if not is_xla_compiled():
    return tf.cumsum(x, axis=axis, exclusive=exclusive)
  x_shape = shape_list(x)
  rank = len(x_shape)
  if axis < 0:
    # Normalize so the permutation below is valid; with a negative axis it
    # would have rank + 1 entries.
    axis += rank
  length = x_shape[axis]
  my_range = tf.range(length)
  comparator = tf.less if exclusive else tf.less_equal
  mask = tf.cast(
      comparator(tf.expand_dims(my_range, 1), tf.expand_dims(my_range, 0)),
      x.dtype)
  ret = tf.tensordot(x, mask, axes=[[axis], [0]])
  if axis != rank - 1:
    ret = tf.transpose(
        ret,
        list(range(axis)) + [rank - 1] + list(range(axis, rank - 1)))
  return ret
def dropout_no_scaling(x, keep_prob):
  """Zero each element with probability 1 - keep_prob, without rescaling.

  Unlike tf.nn.dropout, kept elements are not divided by keep_prob. The mask
  is cast to x's dtype, so integer tensors work too.

  Args:
    x: a Tensor
    keep_prob: a floating point number

  Returns:
    Tensor of the same shape as x.
  """
  if keep_prob == 1.0:
    return x
  keep = tf.less(tf.random_uniform(tf.shape(x)), keep_prob)
  return x * cast_like(keep, x)
def embedding(x,
              vocab_size,
              dense_size,
              name=None,
              reuse=None,
              multiplier=1.0,
              symbol_dropout_rate=0.0,
              embedding_var=None,
              dtype=tf.float32):
  """Embed x of type int64 into dense vectors, reducing to max 4 dimensions.

  Args:
    x: an integer Tensor of ids.
    vocab_size: number of rows of the embedding table (if created here).
    dense_size: embedding width (if the table is created here).
    name: variable scope name; defaults to "embedding".
    reuse: variable scope reuse flag.
    multiplier: scalar applied to the looked-up vectors.
    symbol_dropout_rate: probability of zeroing each input id (no rescaling).
    embedding_var: optional existing table; a "kernel" variable is created
      when None.
    dtype: dtype for the scope and for the TPU one-hot gather.

  Returns:
    The embedded Tensor. A rank-5 result has axis 3 squeezed away.
  """
  with tf.variable_scope(
      name, default_name="embedding", values=[x], reuse=reuse, dtype=dtype):
    if embedding_var is None:
      embedding_var = tf.get_variable("kernel", [vocab_size, dense_size])
    if not tf.executing_eagerly():
      # Densify the IndexedSlices gradient before it is sent back to the
      # parameter server, which avoids extra work there.
      embedding_var = convert_gradient_to_tensor(embedding_var)
    ids = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
    emb_x = gather(embedding_var, ids, dtype)
    if multiplier != 1.0:
      emb_x *= multiplier
    rank = len(emb_x.shape.as_list())
    if rank >= 5:
      assert rank == 5
      # An extra channel axis is assumed to have size 1 (shape[3] == 1).
      emb_x = tf.squeeze(emb_x, 3)
    return emb_x
def shift_right(x, pad_value=None):
  """Shift axis 1 of a 4-d tensor right by one, dropping the last position.

  Axis 1 is prefixed with zeros, or with pad_value (concatenated on axis 1)
  when one is given. The output keeps x's length only if pad_value has size 1
  on axis 1.
  """
  if pad_value is not None:
    padded = tf.concat([pad_value, x], axis=1)
  else:
    padded = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
  return padded[:, :-1, :, :]
def shift_right_3d(x, pad_value=None):
  """Shift axis 1 of a 3-d tensor right by one, dropping the last position.

  Axis 1 is prefixed with zeros, or with pad_value (concatenated on axis 1)
  when one is given.
  """
  if pad_value is not None:
    padded = tf.concat([pad_value, x], axis=1)
  else:
    padded = tf.pad(x, [[0, 0], [1, 0], [0, 0]])
  return padded[:, :-1, :]
def shift_right_2d(x, pad_value=None):
  """Shift axis 1 of a 2-d tensor right by one, dropping the last position.

  Axis 1 is prefixed with zeros, or with pad_value (concatenated on axis 1)
  when one is given.
  """
  if pad_value is not None:
    padded = tf.concat([pad_value, x], axis=1)
  else:
    padded = tf.pad(x, [[0, 0], [1, 0]])
  return padded[:, :-1]
def conv_stride2_multistep(x, nbr_steps, output_filters, name=None, reuse=None):
  """Use a strided convolution to downsample x by 2, `nbr_steps` times.

  We use stride and filter size 2 to avoid the checkerboard problem of deconvs.
  As detailed in http://distill.pub/2016/deconv-checkerboard/.

  Args:
    x: a rank-4 `Tensor` `[batch, spatial_1, spatial_2, depth]`. `conv`
      requires a statically known rank of 4, and the (2, 2) kernel needs both
      spatial axes to have size >= 2.
    nbr_steps: number of halving downsample rounds to apply
    output_filters: an int specifying the filter count for the convolutions
    name: a string
    reuse: a boolean

  Returns:
    A pair (out, hidden_layers):
      out: a `Tensor` with shape roughly
        `[batch, spatial_1 / (2**nbr_steps), spatial_2 / (2**nbr_steps),
        output_filters]`.
      hidden_layers: the list `[x, step_1, ..., step_nbr_steps]` of
        intermediate results, whose last element is `out`.
    When nbr_steps == 0, x is projected with a 1x1 conv and the result is
    `(out, [out])`.
  """
  with tf.variable_scope(
      name, default_name="conv_stride2_multistep", values=[x], reuse=reuse):
    if nbr_steps == 0:
      out = conv(x, output_filters, (1, 1))
      return out, [out]
    hidden_layers = [x]
    for i in range(nbr_steps):
      hidden_layers.append(
          conv(
              hidden_layers[-1],
              output_filters, (2, 2),
              strides=2,
              activation=tf.nn.relu,
              name="conv" + str(i)))
    return hidden_layers[-1], hidden_layers
def deconv_stride2_multistep(x,
                             nbr_steps,
                             output_filters,
                             name=None,
                             reuse=None):
  """Use a deconvolution to upsample x by 2**`nbr_steps`.

  Each step is a 1x1 conv with relu followed by a reshape:
    * if axis 2 has size 1 (1-D data), the conv outputs 2 * output_filters
      channels, which are reshaped to double axis 1;
    * otherwise it outputs 4 * output_filters channels, and
      tf.depth_to_space(., 2) doubles both axis 1 and axis 2.
  When the size of axis 2 is not known statically, tf.cond chooses between
  the two paths at run time.

  Args:
    x: a rank-4 `Tensor`, `[batch, spatial, 1, depth]` for 1-D data or
      `[batch, spatial_1, spatial_2, depth]` (`conv` requires a statically
      known rank of 4).
    nbr_steps: an int specifying the number of doubling upsample rounds to
      apply.
    output_filters: an int specifying the filter count for the deconvolutions
    name: a string
    reuse: a boolean

  Returns:
    a `Tensor` with shape `[batch, spatial * (2**nbr_steps), 1, output_filters]`
    or `[batch, spatial_1 * (2**nbr_steps), spatial_2 * (2**nbr_steps),
    output_filters]`
  """
  with tf.variable_scope(
      name, default_name="deconv_stride2_multistep", values=[x], reuse=reuse):

    def deconv1d(cur, i):
      # Double axis 1: 1x1 conv to 2*filters, then fold the extra channels
      # into the length axis.
      cur_shape = shape_list(cur)
      thicker = conv(
          cur,
          output_filters * 2, (1, 1),
          padding="SAME",
          activation=tf.nn.relu,
          name="deconv1d" + str(i))
      return tf.reshape(thicker,
                        [cur_shape[0], cur_shape[1] * 2, 1, output_filters])

    def deconv2d(cur, i):
      # Double axes 1 and 2: 1x1 conv to 4*filters, then depth_to_space(2).
      thicker = conv(
          cur,
          output_filters * 4, (1, 1),
          padding="SAME",
          activation=tf.nn.relu,
          name="deconv2d" + str(i))
      return tf.depth_to_space(thicker, 2)

    cur = x
    for i in range(nbr_steps):
      if cur.get_shape()[2] == 1:
        cur = deconv1d(cur, i)
      else:
        # Presumably shape_list returns a Python int when the size is static
        # and a scalar Tensor otherwise; hence the isinstance check.
        cur_dim = shape_list(cur)[2]
        if isinstance(cur_dim, int):
          if cur_dim == 1:
            cur = deconv1d(cur, i)
          else:
            cur = deconv2d(cur, i)
        else:
          # `idx=i` binds the loop index at lambda creation time.
          cur = tf.cond(
              tf.equal(cur_dim, 1),
              lambda idx=i: deconv1d(cur, idx),
              lambda idx=i: deconv2d(cur, idx))
    return cur
def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
  """Conditional conv_fn making kernel 1d or 2d depending on inputs shape.

  Adds a "LEFT" padding mode on top of conv_fn. The input is zero-padded only
  before each position: axis 1 by 2 * (k_h // 2) * dilation_h, and axis 2 by
  2 * (k_w // 2) * dilation_w unless axis 2 has size 1. The conv then runs with
  "VALID" padding. conv_fn is called with name "<name or 'conv'>_single", and
  any "force2d" kwarg is not forwarded to it.

  Args:
    conv_fn: callable(inputs, filters, kernel_size, name=..., **kwargs).
    inputs: a Tensor with statically known rank 4.
    filters: number of output filters.
    kernel_size: a pair of ints; both must be odd for "LEFT" padding.
    **kwargs: forwarded to conv_fn ("padding", "dilation_rate" and "name"
      are also read here).

  Returns:
    The result of conv_fn.

  Raises:
    ValueError: if inputs does not have a statically known rank of 4.
  """
  static_shape = inputs.get_shape()
  if not static_shape or len(static_shape) != 4:
    raise ValueError("Inputs to conv must have statically known rank 4. "
                     "Shape: " + str(static_shape))
  # Add support for left padding.
  if kwargs.get("padding") == "LEFT":
    dilation_rate = (1, 1)
    if "dilation_rate" in kwargs:
      dilation_rate = kwargs["dilation_rate"]
    assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1
    height_padding = 2 * (kernel_size[0] // 2) * dilation_rate[0]
    # Skip width padding when axis 2 has size 1 at run time.
    cond_padding = tf.cond(
        tf.equal(shape_list(inputs)[2], 1), lambda: tf.constant(0),
        lambda: tf.constant(2 * (kernel_size[1] // 2) * dilation_rate[1]))
    # ...or statically, when that is already known.
    width_padding = 0 if static_shape[2] == 1 else cond_padding
    padding = [[0, 0], [height_padding, 0], [width_padding, 0], [0, 0]]
    inputs = tf.pad(inputs, padding)
    # Set middle two dimensions to None to prevent convolution from complaining
    inputs.set_shape([static_shape[0], None, None, static_shape[3]])
    kwargs["padding"] = "VALID"

  def conv2d_kernel(kernel_size_arg, name_suffix):
    """Call conv2d but add suffix to name."""
    name = "{}_{}".format(kwargs.get("name", "conv"), name_suffix)
    # Pop "name" (passed explicitly below) and "force2d" (not forwarded to
    # conv_fn); both are restored afterwards.
    original_name = kwargs.pop("name", None)
    original_force2d = kwargs.pop("force2d", None)
    result = conv_fn(inputs, filters, kernel_size_arg, name=name, **kwargs)
    if original_name is not None:
      kwargs["name"] = original_name  # Restore for other calls.
    if original_force2d is not None:
      kwargs["force2d"] = original_force2d
    return result

  return conv2d_kernel(kernel_size, "single")
def conv(inputs, filters, kernel_size, dilation_rate=(1, 1), **kwargs):
  """2-D convolution with a Keras Conv2D layer, via conv_internal.

  conv_internal adds "LEFT" padding support and requires rank-4 inputs.

  Args:
    inputs: a rank-4 Tensor.
    filters: number of output filters.
    kernel_size: kernel size passed to Conv2D.
    dilation_rate: dilation passed to Conv2D (also used for "LEFT" padding).
    **kwargs: extra Conv2D / conv_internal arguments.

  Returns:
    The convolution output.
  """
  def _conv2d(x, *args, **layer_kwargs):
    layer = layers().Conv2D(*args, **layer_kwargs)
    return layer(x)

  kwargs["dilation_rate"] = dilation_rate
  return conv_internal(_conv2d, inputs, filters, kernel_size, **kwargs)
def conv1d(inputs, filters, kernel_size, dilation_rate=1, **kwargs):
  """1-D convolution over axis 1 of a rank-3 tensor.

  A size-1 axis is inserted at position 2 so `conv` can run a
  (kernel_size, 1) kernel, and it is squeezed away afterwards.
  """
  as_image = tf.expand_dims(inputs, 2)
  outputs = conv(as_image, filters, (kernel_size, 1),
                 dilation_rate=(dilation_rate, 1), **kwargs)
  return tf.squeeze(outputs, 2)
def separable_conv(inputs, filters, kernel_size, **kwargs):
  """Depthwise-separable 2-D convolution with Keras SeparableConv2D.

  Goes through conv_internal, so "LEFT" padding is supported and inputs must
  have rank 4.
  """
  def _sep_conv2d(x, *args, **layer_kwargs):
    layer = layers().SeparableConv2D(*args, **layer_kwargs)
    return layer(x)

  return conv_internal(_sep_conv2d, inputs, filters, kernel_size, **kwargs)
def subseparable_conv(inputs, filters, kernel_size, **kwargs):
  """Sub-separable convolution. If separability == 0 it's a separable_conv.

  The optional "separability" kwarg selects the variant:
    * > 0: split the channel axis into `separability` groups, apply a
      Conv2D(filters // separability) to each, concatenate and, when
      separability > 1, mix with a 1x1 Conv2D(filters);
    * < 0: split into |separability| groups, apply a
      SeparableConv2D(filters // |separability|) to each and concatenate;
    * 0 / None: one SeparableConv2D(filters) over the whole input.
  tf.split requires the channel count to be divisible by |separability|.
  """

  def conv_fn(inputs, filters, kernel_size, **kwargs):
    """Sub-separable convolution, splits into separability-many blocks."""
    separability = None
    if "separability" in kwargs:
      # Removed so it is not passed to the Keras layers; restored below.
      separability = kwargs.pop("separability")
    if separability:
      parts = []
      abs_sep = separability if separability > 0 else -1 * separability
      for split_idx, split in enumerate(tf.split(inputs, abs_sep, axis=3)):
        with tf.variable_scope("part_%d" % split_idx):
          if separability > 0:
            parts.append(
                layers().Conv2D(filters // separability, kernel_size,
                                **kwargs)(split))
          else:
            parts.append(
                layers().SeparableConv2D(filters // abs_sep,
                                         kernel_size, **kwargs)(split))
      if separability > 1:
        # Mix the independent groups with a 1x1 convolution.
        result = layers().Conv2D(filters, (1, 1))(tf.concat(parts, axis=3))
      elif abs_sep == 1:  # If we have just one block, return it.
        assert len(parts) == 1
        result = parts[0]
      else:
        result = tf.concat(parts, axis=3)
    else:
      result = layers().SeparableConv2D(filters, kernel_size,
                                        **kwargs)(inputs)
    if separability is not None:
      kwargs["separability"] = separability
    return result

  return conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs)
def tpu_conv1d(inputs, filters, kernel_size, padding="SAME", name="tpu_conv1d"):
  """Version of conv1d that works on TPU (as of 11/2017).

  The convolution is a sum of kernel_size dense layers, each applied to a
  copy of inputs shifted along the length axis, scaled by kernel_size**-0.5.
  Tap i reads the position at offset (first_offset + i). With "SAME" padding
  the offsets are -(k-1)/2 .. +(k-1)/2; with "LEFT" padding they are
  -(k-1) .. 0, so no position sees later positions.

  Args:
    inputs: a Tensor with shape [batch, length, input_depth].
    filters: an integer.
    kernel_size: an integer (must be odd for "SAME").
    padding: a string - "SAME" or "LEFT".
    name: a string.

  Returns:
    a Tensor with shape [batch, length, filters].
  """
  if kernel_size == 1:
    return dense(inputs, filters, name=name, use_bias=True)
  if padding == "SAME":
    assert kernel_size % 2 == 1
    first_offset = -((kernel_size - 1) // 2)
  else:
    assert padding == "LEFT"
    first_offset = -(kernel_size - 1)
  last_offset = first_offset + kernel_size - 1
  results = []
  padded = tf.pad(inputs, [[0, 0], [-first_offset, last_offset], [0, 0]])
  for i in range(kernel_size):
    # padded[:, i:i + length] equals inputs shifted by (first_offset + i).
    # Only the zero-offset tap can reuse inputs directly; every other tap,
    # including i == 0, must read its own shifted slice.
    if i + first_offset == 0:
      shifted = inputs
    else:
      shifted = tf.slice(padded, [0, i, 0], tf.shape(inputs))
      shifted.set_shape(inputs.get_shape())
    results.append(
        dense(shifted, filters, use_bias=(i == 0), name=name + "_%d" % i))
  ret = tf.add_n(results)
  ret *= kernel_size**-0.5
  return ret
def layer_norm_vars(filters):
  """Create the layer-norm scale (ones) and bias (zeros) variables.

  Args:
    filters: size of the normalized (last) axis.

  Returns:
    A pair (scale, bias) of [filters] variables.
  """
  def _make(var_name, init):
    return tf.get_variable(var_name, [filters], initializer=init)

  scale = _make("layer_norm_scale", tf.ones_initializer())
  bias = _make("layer_norm_bias", tf.zeros_initializer())
  return scale, bias
def layer_norm_compute(x, epsilon, scale, bias, layer_collection=None):
  """Layer norm raw computation.

  Normalizes x over its last axis to zero mean and unit variance, then applies
  scale and bias.

  Args:
    x: a Tensor.
    epsilon: a float added to the variance for numerical stability.
    scale: per-channel scale, broadcast over x's last axis.
    bias: per-channel bias, broadcast over x's last axis.
    layer_collection: optional tensorflow_kfac.LayerCollection. When given,
      (scale, bias) are registered with it as a scale-and-shift layer so the
      KFAC optimizer can precondition them.

  Returns:
    A Tensor with the same shape as x.
  """
  # Save these before they get converted to tensors by the casting below; the
  # KFAC registration needs the original variables.
  params = (scale, bias)
  epsilon, scale, bias = [cast_like(t, x) for t in [epsilon, scale, bias]]
  mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
  variance = tf.reduce_mean(
      tf.squared_difference(x, mean), axis=[-1], keepdims=True)
  norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
  output = norm_x * scale + bias
  if layer_collection is not None:
    # NOTE(review): KFAC assumes the first axis of norm_x is the batch axis.
    layer_collection.register_scale_and_shift(
        params, norm_x, output, approx="full")
  return output
def layer_norm(x,
               filters=None,
               epsilon=1e-6,
               name=None,
               reuse=None,
               layer_collection=None):
  """Layer normalize the tensor x, averaging over the last dimension.

  Args:
    x: a Tensor.
    filters: size of the last axis; inferred from x when None.
    epsilon: added to the variance for numerical stability.
    name: variable scope name; defaults to "layer_norm".
    reuse: variable scope reuse flag.
    layer_collection: optional KFAC LayerCollection, passed through.

  Returns:
    A Tensor with the same shape as x.
  """
  depth = shape_list(x)[-1] if filters is None else filters
  with tf.variable_scope(
      name, default_name="layer_norm", values=[x], reuse=reuse):
    scale, bias = layer_norm_vars(depth)
    return layer_norm_compute(
        x, epsilon, scale, bias, layer_collection=layer_collection)
def group_norm(x, filters=None, num_groups=8, epsilon=1e-5):
  """Group normalization as in https://arxiv.org/abs/1803.08494.

  The channel axis of a rank-4 x is split into num_groups groups. Each group
  is normalized using statistics over axes 1, 2 and the within-group channels,
  and a per-channel scale and bias are applied.

  Args:
    x: a rank-4 Tensor.
    filters: number of channels; inferred from x when None.
    num_groups: number of groups; must divide filters.
    epsilon: added to the variance for numerical stability.

  Returns:
    A Tensor with the same shape as x.
  """
  x_shape = shape_list(x)
  filters = x_shape[-1] if filters is None else filters
  assert len(x_shape) == 4
  assert filters % num_groups == 0
  # Prepare variables.
  scale = tf.get_variable(
      "group_norm_scale", [filters], initializer=tf.ones_initializer())
  bias = tf.get_variable(
      "group_norm_bias", [filters], initializer=tf.zeros_initializer())
  epsilon, scale, bias = [cast_like(t, x) for t in (epsilon, scale, bias)]
  group_size = filters // num_groups
  grouped = tf.reshape(x, x_shape[:-1] + [num_groups, group_size])
  mean, variance = tf.nn.moments(grouped, [1, 2, 4], keep_dims=True)
  normalized = (grouped - mean) * tf.rsqrt(variance + epsilon)
  return tf.reshape(normalized, x_shape) * scale + bias
def noam_norm(x, epsilon=1.0, name=None):
  """L2-normalize x over its last axis and rescale by sqrt(last-axis size)."""
  with tf.name_scope(name, default_name="noam_norm", values=[x]):
    static_shape = x.get_shape()
    last_axis = len(static_shape) - 1
    unit = tf.nn.l2_normalize(x, last_axis, epsilon=epsilon)
    return unit * tf.sqrt(to_float(static_shape[-1]))
def l2_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
  """Layer normalization with l2 norm.

  Like layer norm, but the centred values are divided by the square root of
  the summed (not averaged) squared deviations over the last axis.
  """
  depth = shape_list(x)[-1] if filters is None else filters
  with tf.variable_scope(name, default_name="l2_norm", values=[x], reuse=reuse):
    scale = tf.get_variable(
        "l2_norm_scale", [depth], initializer=tf.ones_initializer())
    bias = tf.get_variable(
        "l2_norm_bias", [depth], initializer=tf.zeros_initializer())
    epsilon, scale, bias = [cast_like(t, x) for t in (epsilon, scale, bias)]
    centred = x - tf.reduce_mean(x, axis=[-1], keepdims=True)
    mean = x - centred if False else tf.reduce_mean(x, axis=[-1], keepdims=True)
    sum_sq = tf.reduce_sum(
        tf.squared_difference(x, mean), axis=[-1], keepdims=True)
    return (x - mean) * tf.rsqrt(sum_sq + epsilon) * scale + bias
def apply_spectral_norm(x):
  """Normalizes x using the spectral norm.

  The implementation follows Algorithm 1 of
  https://arxiv.org/abs/1802.05957: one power-iteration step with a
  persistent, non-trainable vector "u". If x is not a 2-D Tensor, it is
  reshaped so that the number of channels (last dimension) stays the same.

  Args:
    x: Tensor with the last dimension equal to the number of filters.

  Returns:
    x: Tensor with the same shape as x normalized by the spectral norm.
    assign_op: Op to be run after every step to update the vector "u".
  """
  shape = shape_list(x)
  rows = tf.reduce_prod(shape[:-1])
  num_filters = shape[-1]
  # W as a [rows, num_filters] matrix.
  w = tf.reshape(x, (rows, num_filters))
  with tf.variable_scope("u", reuse=tf.AUTO_REUSE):
    u = tf.get_variable(
        "u", [num_filters, 1],
        initializer=tf.truncated_normal_initializer(),
        trainable=False)
  # v = W u / ||W u||
  v = tf.nn.l2_normalize(tf.matmul(w, u))
  v_t = tf.transpose(v)
  # u_new = v^T W / ||v^T W||, shape [1, num_filters].
  u_new = tf.nn.l2_normalize(tf.matmul(v_t, w))
  # sigma = v^T W u_new^T
  sigma = tf.squeeze(tf.matmul(v_t, tf.matmul(w, tf.transpose(u_new))))
  # Store u_new so the next step continues the power iteration.
  assign_op = tf.assign(u, tf.transpose(u_new))
  return tf.divide(x, sigma), assign_op
def apply_norm(x, norm_type, depth, epsilon, layer_collection=None):
  """Apply Normalization.

  Args:
    x: a Tensor.
    norm_type: one of "layer", "group", "batch", "noam", "l2", "none".
    depth: channel count passed as `filters` to layer/group/l2 norm; None
      lets those functions infer it from x.
    epsilon: a float (parameter for normalization).
    layer_collection: optional KFAC LayerCollection; only supported with
      norm_type == "layer".

  Returns:
    The normalized Tensor (x itself for "none").

  Raises:
    ValueError: if norm_type is not recognized.
  """
  if layer_collection is not None:
    assert norm_type == "layer"
  if norm_type == "layer":
    return layer_norm(
        x, filters=depth, epsilon=epsilon, layer_collection=layer_collection)
  if norm_type == "group":
    return group_norm(x, filters=depth, epsilon=epsilon)
  if norm_type == "batch":
    return layers().BatchNormalization(epsilon=epsilon)(x)
  if norm_type == "noam":
    return noam_norm(x, epsilon)
  if norm_type == "l2":
    return l2_norm(x, filters=depth, epsilon=epsilon)
  if norm_type == "none":
    return x
  raise ValueError("Parameter norm_type must be one of: 'layer', 'group', "
                   "'batch', 'noam', 'l2', 'none'; got %r." % (norm_type,))
def zero_add(previous_value, x, name=None, reuse=None):
  """Resnet connection with zero initialization.

  Returns previous_value + gamma * x, where gamma is a trainable scalar that
  starts at zero. This is useful when a module is plugged into a trained model
  and it must initially reproduce the original model's output.

  Args:
    previous_value: A tensor.
    x: A tensor.
    name: name of variable scope; defaults to zero_add.
    reuse: reuse scope.

  Returns:
    previous_value + gamma * x.
  """
  with tf.variable_scope(name, default_name="zero_add", reuse=reuse):
    gate = tf.get_variable("gamma", (), initializer=tf.zeros_initializer())
    scaled = gate * x
    return previous_value + scaled
def layer_prepostprocess(previous_value,
                         x,
                         sequence,
                         dropout_rate,
                         norm_type,
                         depth,
                         epsilon,
                         default_name,
                         name=None,
                         dropout_broadcast_dims=None,
                         layer_collection=None):
  """Apply a sequence of functions to the input or output of a layer.

  Each character of `sequence` is one step, applied left to right:
    a: add previous_value
    n: apply normalization
    d: apply dropout
    z: zero add
  For example, sequence == "dna" yields previous_value + normalize(dropout(x)).
  The literal string "none" returns x unchanged.

  Args:
    previous_value: A Tensor, to be added as a residual connection ('a')
    x: A Tensor to be transformed.
    sequence: a string.
    dropout_rate: a float
    norm_type: a string (see apply_norm())
    depth: an integer (size of last dimension of x).
    epsilon: a float (parameter for normalization)
    default_name: a string
    name: a string
    dropout_broadcast_dims: an optional list of integers less than 3
      specifying in which dimensions to broadcast the dropout decisions;
      saves memory.
    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
      KFAC optimizer. Default is None.

  Returns:
    a Tensor
  """
  with tf.variable_scope(name, default_name=default_name):
    if sequence == "none":
      return x
    output = x
    for step in sequence:
      if step == "n":
        output = apply_norm(
            output, norm_type, depth, epsilon,
            layer_collection=layer_collection)
      elif step == "a":
        output += previous_value
      elif step == "z":
        output = zero_add(previous_value, output)
      else:
        assert step == "d", ("Unknown sequence step %s" % step)
        output = dropout_with_broadcast_dims(
            output, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    return output
def layer_preprocess(layer_input, hparams, layer_collection=None):
  """Apply layer preprocessing.

  Runs layer_prepostprocess() with hparams.layer_preprocess_sequence, which
  must not contain residual steps ('a' or 'z'). Reads these hparams:
  layer_preprocess_sequence, layer_prepostprocess_dropout, norm_type,
  norm_epsilon and, optionally, layer_prepostprocess_dropout_broadcast_dims.

  Args:
    layer_input: a Tensor
    hparams: a hyperparameters object.
    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
      KFAC optimizer. Default is None.

  Returns:
    a Tensor
  """
  sequence = hparams.layer_preprocess_sequence
  for residual_step in ("a", "z"):
    assert residual_step not in sequence, (
        "No residual connections allowed in hparams.layer_preprocess_sequence")
  broadcast_dims = comma_separated_string_to_integer_list(
      getattr(hparams, "layer_prepostprocess_dropout_broadcast_dims", ""))
  return layer_prepostprocess(
      None,
      layer_input,
      sequence=sequence,
      dropout_rate=hparams.layer_prepostprocess_dropout,
      norm_type=hparams.norm_type,
      depth=None,
      epsilon=hparams.norm_epsilon,
      dropout_broadcast_dims=broadcast_dims,
      default_name="layer_prepostprocess",
      layer_collection=layer_collection)
def layer_postprocess(layer_input, layer_output, hparams):
  """Apply layer postprocessing.

  Runs layer_prepostprocess() on layer_output with
  hparams.layer_postprocess_sequence, using layer_input as the residual
  value. Reads these hparams: layer_postprocess_sequence,
  layer_prepostprocess_dropout, norm_type, norm_epsilon and, optionally,
  layer_prepostprocess_dropout_broadcast_dims.

  Args:
    layer_input: a Tensor
    layer_output: a Tensor
    hparams: a hyperparameters object.

  Returns:
    a Tensor
  """
  broadcast_dims = comma_separated_string_to_integer_list(
      getattr(hparams, "layer_prepostprocess_dropout_broadcast_dims", ""))
  return layer_prepostprocess(
      layer_input,
      layer_output,
      sequence=hparams.layer_postprocess_sequence,
      dropout_rate=hparams.layer_prepostprocess_dropout,
      norm_type=hparams.norm_type,
      depth=None,
      epsilon=hparams.norm_epsilon,
      dropout_broadcast_dims=broadcast_dims,
      default_name="layer_postprocess")
def conv_block_internal(conv_fn,
                        inputs,
                        filters,
                        dilation_rates_and_kernel_sizes,
                        first_relu=True,
                        use_elu=False,
                        separabilities=None,
                        **kwargs):
  """A block of convolutions.

  Each layer: optional relu/elu (skipped for the first layer when
  first_relu is False), optional multiplication by `mask`, conv_fn, then
  normalization. With no "normalizer_fn" kwarg, layer_norm is used; with
  normalizer_fn=None, no normalization is applied and the convs get a bias.

  Args:
    conv_fn: convolution function, e.g. conv or separable_conv.
    inputs: a Tensor
    filters: an Integer
    dilation_rates_and_kernel_sizes: a list of tuples (dilation, (k_w, k_h))
    first_relu: whether to do a relu at start (defaults to True)
    use_elu: whether to use ELUs instead of ReLUs (defaults to False)
    separabilities: list of separability factors (per-layer).
    **kwargs: additional arguments (e.g., pooling, name, mask, normalizer_fn)

  Returns:
     a Tensor.
  """
  name = kwargs.pop("name", None)
  mask = kwargs.pop("mask", None)
  if "normalizer_fn" in kwargs:
    norm = kwargs.pop("normalizer_fn")
    apply_normalizer = bool(norm)
  else:
    norm = lambda t, name: layer_norm(t, filters, name=name)
    apply_normalizer = True

  with tf.variable_scope(name, "conv_block", [inputs]):
    cur = inputs
    for counter, (dilation_rate, kernel_size) in enumerate(
        dilation_rates_and_kernel_sizes):
      if first_relu or counter > 0:
        cur = tf.nn.elu(cur) if use_elu else tf.nn.relu(cur)
      if mask is not None:
        cur *= mask
      if separabilities:
        # dict(...) raises TypeError on a duplicate "separability", matching a
        # repeated keyword in a direct call.
        layer_kwargs = dict(separability=separabilities[counter], **kwargs)
      else:
        layer_kwargs = kwargs
      cur = conv_fn(
          cur,
          filters,
          kernel_size,
          dilation_rate=dilation_rate,
          name="conv_block_%d" % counter,
          use_bias=norm is None,
          **layer_kwargs)
      if apply_normalizer:
        cur = norm(cur, name="conv_block_norm_%d" % counter)
    return cur
def conv_block(inputs, filters, dilation_rates_and_kernel_sizes, **kwargs):
  """A block of standard 2d convolutions (conv_block_internal with conv)."""
  block_fn = conv
  return conv_block_internal(block_fn, inputs, filters,
                             dilation_rates_and_kernel_sizes, **kwargs)
def conv1d_block(inputs, filters, dilation_rates_and_kernel_sizes, **kwargs):
  """A block of standard 1d convolutions (conv_block_internal with conv1d)."""
  block_fn = conv1d
  return conv_block_internal(block_fn, inputs, filters,
                             dilation_rates_and_kernel_sizes, **kwargs)
def separable_conv_block(inputs, filters, dilation_rates_and_kernel_sizes,
                         **kwargs):
  """A block of separable convolutions (conv_block_internal with separable_conv)."""
  block_fn = separable_conv
  return conv_block_internal(block_fn, inputs, filters,
                             dilation_rates_and_kernel_sizes, **kwargs)
def subseparable_conv_block(inputs, filters, dilation_rates_and_kernel_sizes,
                            **kwargs):
  """A block of sub-separable convolutions.

  conv_block_internal with subseparable_conv. Per-layer separability factors
  can be given via the `separabilities` kwarg, or a single `separability`
  kwarg is passed to every layer.
  """
  return conv_block_internal(subseparable_conv, inputs, filters,
                             dilation_rates_and_kernel_sizes, **kwargs)
def pool(inputs, window_size, pooling_type, padding, strides=(1, 1)):
  """Pooling (supports "LEFT").

  "LEFT" padding zero-pads only before each position: axis 1 by
  2 * (window_h // 2), and axis 2 by 2 * (window_w // 2) unless axis 2 has
  size 1. Pooling then runs with "VALID" padding.

  Args:
    inputs: a Tensor with statically known rank 4.
    window_size: a pair of ints; both must be odd for "LEFT" padding.
    pooling_type: "MAX" or "AVG", as accepted by tf.nn.pool.
    padding: "SAME", "VALID" or "LEFT".
    strides: pooling strides.

  Returns:
    The pooled Tensor.

  Raises:
    ValueError: if inputs does not have a statically known rank of 4.
  """
  with tf.name_scope("pool", values=[inputs]):
    static_shape = inputs.get_shape()
    if not static_shape or len(static_shape) != 4:
      raise ValueError("Inputs to pool must have statically known rank 4.")
    # Add support for left padding. (Rank is always 4 here, so no separate
    # rank-3 case is needed.)
    if padding == "LEFT":
      assert window_size[0] % 2 == 1 and window_size[1] % 2 == 1
      height_padding = 2 * (window_size[0] // 2)
      cond_padding = tf.cond(
          tf.equal(shape_list(inputs)[2], 1), lambda: tf.constant(0),
          lambda: tf.constant(2 * (window_size[1] // 2)))
      width_padding = 0 if static_shape[2] == 1 else cond_padding
      padding_ = [[0, 0], [height_padding, 0], [width_padding, 0], [0, 0]]
      inputs = tf.pad(inputs, padding_)
      inputs.set_shape([static_shape[0], None, None, static_shape[3]])
      padding = "VALID"
  return tf.nn.pool(inputs, window_size, pooling_type, padding, strides=strides)
def conv_block_downsample(x,
                          kernel,
                          strides,
                          padding,
                          separability=0,
                          name=None,
                          reuse=None):
  """Implements a downwards-striding conv block, like Xception exit flow.

  A strided residual conv ("res_conv", 1.25x width) is added to the output of
  two sub-separable convs plus max pooling. Two more sub-separable convs then
  widen the result to 2x and 2.5x the input depth.
  """
  with tf.variable_scope(
      name, default_name="conv_block_downsample", values=[x], reuse=reuse):
    hidden_size = int(x.get_shape()[-1])
    widened = int(1.25 * hidden_size)
    layer_spec = [((1, 1), kernel)]
    res = conv_block(
        x, widened, layer_spec,
        padding=padding, strides=strides, name="res_conv")
    for width, scope in ((hidden_size, "conv0"), (widened, "conv1")):
      x = subseparable_conv_block(
          x, width, layer_spec,
          padding=padding, separability=separability, name=scope)
    x = pool(x, kernel, "MAX", padding, strides=strides)
    x += res
    x = subseparable_conv_block(
        x, 2 * hidden_size, layer_spec, first_relu=False,
        padding=padding, separability=separability, name="conv2")
    return subseparable_conv_block(
        x, int(2.5 * hidden_size), layer_spec,
        padding=padding, separability=separability, name="conv3")
def get_timing_signal(length,
                      min_timescale=1,
                      max_timescale=1e4,
                      num_timescales=16):
  """Create Tensor of sinusoids of different frequencies.

  The timescales form a geometric sequence from min_timescale to
  max_timescale. For each one, sin and cos of position / timescale are
  emitted.

  Args:
    length: Length of the Tensor to create, i.e. Number of steps.
    min_timescale: a float
    max_timescale: a float
    num_timescales: an int >= 1

  Returns:
    Tensor of shape (length, 2*num_timescales)
  """
  positions = to_float(tf.range(length))
  # Guard the denominator so num_timescales == 1 does not divide by zero
  # (the single timescale is then min_timescale).
  log_timescale_increment = (
      math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      to_float(tf.range(num_timescales)) * -log_timescale_increment)
  scaled_time = tf.expand_dims(positions, 1) * tf.expand_dims(inv_timescales, 0)
  return tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
def add_timing_signal(x, min_timescale=1, max_timescale=1e4, num_timescales=16):
  """Adds a bunch of sinusoids of different frequencies to a Tensor.

  This allows attention to learn to use absolute and relative positions. The
  timing signal should be added to some precursor of both the source and the
  target of the attention. Relative positions are usable because sin(x+y) and
  cos(x+y) can be expressed in terms of y, sin(x) and cos(x).

  The signal from get_timing_signal (sin and cos for geometrically spaced
  timescales) is zero-padded on the depth axis to x's depth and added to x.

  Args:
    x: a Tensor with shape [?, length, ?, depth]
    min_timescale: a float
    max_timescale: a float
    num_timescales: an int <= depth/2

  Returns:
    a Tensor the same shape as x.
  """
  dims = shape_list(x)
  length, depth = dims[1], dims[3]
  signal = get_timing_signal(length, min_timescale, max_timescale,
                             num_timescales)
  signal = tf.pad(signal, [[0, 0], [0, depth - 2 * num_timescales]])
  return x + tf.reshape(signal, [1, length, 1, depth])
def mask_from_embedding(emb):
  """Input embeddings -> padding mask.

  symbol_modality is hacked to return all-zero embeddings for padding, so a
  position is treated as padding when the sum of |emb| over axis 3 is zero.
  The mask is 0.0 at padding positions and 1.0 elsewhere.

  Args:
    emb: a Tensor with shape [batch, width, height, depth].

  Returns:
    a 0.0/1.0 Tensor with shape [batch, width, height, 1].
  """
  magnitude = tf.reduce_sum(tf.abs(emb), axis=3, keepdims=True)
  return weights_nonzero(magnitude)
def length_from_embedding(emb):
  """Count the non-padding positions of each sequence in the batch.

  Args:
    emb: a sequence embedding Tensor with shape [batch, max_time, 1, depth].

  Returns:
    an int32 Tensor with shape [batch].
  """
  nonpadding = mask_from_embedding(emb)
  return tf.cast(tf.reduce_sum(nonpadding, [1, 2, 3]), tf.int32)
def mask_pos_gt(source_length, target_length):
  """A mask with 1.0 wherever source_pos > target_pos and 0.0 elsewhere.

  Entry [0, t, s] is 1.0 iff s > t, matching the [target, source] layout of
  mask_leq.

  Args:
    source_length: an integer
    target_length: an integer

  Returns:
    a Tensor with shape [1, target_length, source_length]
  """
  # Rows index target positions and columns index source positions. (The
  # ranges were previously swapped, producing [1, source, target].)
  return tf.expand_dims(
      tf.cast(tf.greater(tf.expand_dims(tf.range(source_length), axis=0),
                         tf.expand_dims(tf.range(target_length), axis=1)),
              dtype=tf.float32), axis=0)
def mask_leq(target_length, source_length):
  """A mask with 1.0 wherever source_pos <= target_pos and 0.0 elsewhere.

  Args:
    target_length: an integer
    source_length: an integer

  Returns:
    a Tensor with shape [1, target_length, source_length]
  """
  # Band arguments (-1, 0): presumably tf.matrix_band_part's
  # (num_lower, num_upper) convention, i.e. keep the lower triangle.
  num_lower, num_upper = -1, 0
  return ones_matrix_band_part(
      target_length, source_length, num_lower, num_upper,
      out_shape=[1, target_length, source_length])
def mask_pos_lt(source_length, target_length):
  """A mask with 1.0 wherever source_pos < target_pos and 0.0 elsewhere.

  Entry [0, t, s] is 1.0 iff s < t, matching the [target, source] layout of
  mask_leq.

  Args:
    source_length: an integer
    target_length: an integer

  Returns:
    a Tensor with shape [1, target_length, source_length]
  """
  # Rows index target positions and columns index source positions. (The
  # ranges were previously swapped, producing [1, source, target].)
  return tf.expand_dims(
      tf.cast(tf.less(tf.expand_dims(tf.range(source_length), axis=0),
                      tf.expand_dims(tf.range(target_length), axis=1)),
              dtype=tf.float32), axis=0)
def relu_density_logit(x, reduce_dims):
  """logit(density(x)), where density is the fraction of x above zero.

  Useful for histograms. exp(-10) is added inside both logs to avoid log(0).

  Args:
    x: a Tensor, typically the output of tf.relu
    reduce_dims: a list of dimensions

  Returns:
    a Tensor
  """
  eps = math.exp(-10)
  frac = tf.reduce_mean(to_float(x > 0.0), reduce_dims)
  return tf.log(frac + eps) - tf.log((1.0 - frac) + eps)
def maybe_zero_out_padding(inputs, kernel_size, nonpadding_mask):
  """If necessary, zero out inputs to a conv for padding positions.

  Masking is skipped for 1 / (1, 1) kernels, which do not mix neighbouring
  positions, and when no mask is supplied.

  Args:
    inputs: a Tensor with shape [batch, length, ...]
    kernel_size: an integer or pair of integers
    nonpadding_mask: a Tensor with shape [batch, length]

  Returns:
    Tensor of the same shape as inputs.
  """
  is_pointwise = kernel_size == 1 or kernel_size == (1, 1)
  if is_pointwise or nonpadding_mask is None:
    return inputs
  # Append trailing unit axes until the mask's rank matches the inputs'.
  mask = nonpadding_mask
  while mask.get_shape().ndims < inputs.get_shape().ndims:
    mask = tf.expand_dims(mask, -1)
  return inputs * mask
def dense_relu_dense(inputs,
                     filter_size,
                     output_size,
                     output_activation=None,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     layer_collection=None,
                     name=None):
  """Hidden layer with RELU activation followed by linear projection.

  Args:
    inputs: input Tensor.
    filter_size: width of the hidden ReLU layer.
    output_size: width of the output projection.
    output_activation: optional activation applied to the output projection.
    dropout: dropout rate on the hidden layer; 0.0 disables it.
    dropout_broadcast_dims: dims along which the dropout mask is broadcast.
    layer_collection: forwarded to both `dense` calls.
    name: optional prefix for the two layer names.

  Returns:
    Output Tensor of the second dense layer.
  """
  # The "conv1"/"conv2" suffixes are kept only for historical reasons (and
  # checkpoint compatibility); both layers are dense.
  name_template = ("%s_{}" % name) if name else "{}"
  hidden = dense(
      inputs,
      filter_size,
      use_bias=True,
      activation=tf.nn.relu,
      layer_collection=layer_collection,
      name=name_template.format("conv1"))
  if dropout != 0.0:
    keep_prob = 1.0 - dropout
    hidden = dropout_with_broadcast_dims(
        hidden, keep_prob, broadcast_dims=dropout_broadcast_dims)
  return dense(
      hidden,
      output_size,
      activation=output_activation,
      use_bias=True,
      layer_collection=layer_collection,
      name=name_template.format("conv2"))
def dense_dropconnect(inputs,
                      output_size,
                      dropconnect_dropout=0.0,
                      name="dense_dropconnect",
                      **kwargs):
  """Dense layer with dropconnect.

  When `dropconnect_dropout` is non-zero, a `tf.nn.dropout` partial with keep
  probability 1 - dropconnect_dropout is passed to `dense` as its
  `kernel_regularizer`. This overrides any regularizer supplied in kwargs.

  NOTE(review): whether a kernel_regularizer perturbs the kernel used in the
  forward pass (true dropconnect) or is only collected as an extra loss term
  depends on `dense`'s underlying layer -- confirm before relying on it.

  Args:
    inputs: input Tensor.
    output_size: number of output units.
    dropconnect_dropout: dropconnect rate; 0.0 disables it.
    name: layer name.
    **kwargs: forwarded to `dense`.

  Returns:
    Output Tensor of `dense`.
  """
  if dropconnect_dropout == 0.0:
    return dense(inputs, output_size, use_bias=True, name=name, **kwargs)
  tf.logging.info("Applying dropconnect as the kernel regularization.")
  kwargs["kernel_regularizer"] = functools.partial(
      tf.nn.dropout, keep_prob=1.0 - dropconnect_dropout)
  return dense(inputs, output_size, use_bias=True, name=name, **kwargs)
def conv_relu_conv(inputs,
                   filter_size,
                   output_size,
                   first_kernel_size=3,
                   second_kernel_size=3,
                   padding="SAME",
                   nonpadding_mask=None,
                   dropout=0.0,
                   name=None,
                   cache=None,
                   decode_loop_step=None):
  """Hidden layer with RELU activation followed by linear projection.

  Runs two `tpu_conv1d` layers: "conv1" (followed by ReLU and optional
  dropout) and "conv2". Before each conv, inputs are zeroed at padding
  positions by `maybe_zero_out_padding` when a mask is given. With a
  non-empty `cache`, the first conv reads from a cached window stored under
  cache["f"], and only the last position of its output is kept.

  Args:
    inputs: A tensor.
    filter_size: An integer.
    output_size: An integer.
    first_kernel_size: An integer.
    second_kernel_size: An integer.
    padding: A string.
    nonpadding_mask: A tensor.
    dropout: A float.
    name: A string.
    cache: A dict, containing Tensors which are the results of previous
      attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
      Only used for inference on TPU. If it is not None, the function
      will do inplace update for the cache instead of concatenating the
      current result to the cache.

  Returns:
    A Tensor.
  """
  with tf.variable_scope(name, "conv_relu_conv", [inputs]):
    inputs = maybe_zero_out_padding(inputs, first_kernel_size, nonpadding_mask)

    if cache:
      if decode_loop_step is None:
        # Append the new positions to the cached window along the length axis.
        inputs = cache["f"] = tf.concat([cache["f"], inputs], axis=1)
      else:
        # Inplace update is required for inference on TPU.
        # Inplace_ops only supports inplace_update on the first dimension.
        # The performance of current implementation is better than updating
        # the tensor by adding the result of matmul(one_hot,
        # update_in_current_step)
        tmp_f = tf.transpose(cache["f"], perm=[1, 0, 2])
        tmp_f = inplace_ops.alias_inplace_update(
            tmp_f,
            decode_loop_step * tf.shape(inputs)[1],
            tf.transpose(inputs, perm=[1, 0, 2]))
        inputs = cache["f"] = tf.transpose(tmp_f, perm=[1, 0, 2])
      # Keep only the last first_kernel_size positions in the cache.
      # NOTE(review): on the TPU branch the cache is updated at index
      # decode_loop_step * length, so the trailing slice may not hold the
      # newest entries, and it shrinks cache["f"] -- confirm intended.
      inputs = cache["f"] = inputs[:, -first_kernel_size:, :]

    h = tpu_conv1d(
        inputs, filter_size, first_kernel_size, padding=padding, name="conv1")

    if cache:
      # Incremental decoding: keep only the output at the last position.
      h = h[:, -1:, :]

    h = tf.nn.relu(h)
    if dropout != 0.0:
      h = tf.nn.dropout(h, 1.0 - dropout)
    h = maybe_zero_out_padding(h, second_kernel_size, nonpadding_mask)
    return tpu_conv1d(
        h, output_size, second_kernel_size, padding=padding, name="conv2")
def sepconv_relu_sepconv(inputs,
                         filter_size,
                         output_size,
                         first_kernel_size=(1, 1),
                         second_kernel_size=(1, 1),
                         padding="LEFT",
                         nonpadding_mask=None,
                         dropout=0.0,
                         name=None):
  """Hidden layer with RELU activation followed by linear projection.

  Two `separable_conv` layers: "conv1" with ReLU, then "conv2". Dropout is
  optional in between. Padding positions are zeroed before each conv when a
  mask is given. A rank-3 input gets a unit axis 2 inserted for the convs,
  and that axis is squeezed away from the result.
  """
  with tf.variable_scope(name, "sepconv_relu_sepconv", [inputs]):
    inputs = maybe_zero_out_padding(inputs, first_kernel_size, nonpadding_mask)
    squeeze_back = inputs.get_shape().ndims == 3
    if squeeze_back:
      inputs = tf.expand_dims(inputs, 2)
    hidden = separable_conv(
        inputs,
        filter_size,
        first_kernel_size,
        activation=tf.nn.relu,
        padding=padding,
        name="conv1")
    if dropout != 0.0:
      hidden = tf.nn.dropout(hidden, 1.0 - dropout)
    hidden = maybe_zero_out_padding(hidden, second_kernel_size, nonpadding_mask)
    output = separable_conv(
        hidden, output_size, second_kernel_size, padding=padding, name="conv2")
    return tf.squeeze(output, 2) if squeeze_back else output
# DEPRECATED - use dense_relu_dense, conv_relu_conv, sepconv_relu_sepconv
def conv_hidden_relu(inputs,
                     hidden_size,
                     output_size,
                     kernel_size=(1, 1),
                     second_kernel_size=(1, 1),
                     dropout=0.0,
                     **kwargs):
  """Hidden layer with RELU activation followed by linear projection.

  Each of the two layers uses `conv` for a (1, 1) kernel and `separable_conv`
  otherwise. A `name` keyword selects the variable scope. All other keyword
  arguments are forwarded to both layers. Rank-3 inputs are temporarily
  given a unit axis 2.
  """
  name = kwargs.pop("name", None)
  with tf.variable_scope(name, "conv_hidden_relu", [inputs]):
    squeeze_back = inputs.get_shape().ndims == 3
    if squeeze_back:
      inputs = tf.expand_dims(inputs, 2)
    first_layer = separable_conv if kernel_size != (1, 1) else conv
    hidden = first_layer(
        inputs,
        hidden_size,
        kernel_size,
        activation=tf.nn.relu,
        name="conv1",
        **kwargs)
    if dropout != 0.0:
      hidden = tf.nn.dropout(hidden, 1.0 - dropout)
    second_layer = separable_conv if second_kernel_size != (1, 1) else conv
    output = second_layer(
        hidden, output_size, second_kernel_size, name="conv2", **kwargs)
    return tf.squeeze(output, 2) if squeeze_back else output
def conv_gru(x,
             kernel_size,
             filters,
             padding="SAME",
             dilation_rate=(1, 1),
             name=None,
             reuse=None):
  """Convolutional GRU in 1 dimension.

  Reset and update gates are `saturating_sigmoid` of convolutions over `x`
  (bias initialised to 1.0). The candidate is tanh of a convolution over the
  reset-gated `x` (bias 0.0). `x` also acts as the previous state.
  """

  def gate_conv(inp, conv_name, bias_start):
    return conv(
        inp,
        filters,
        kernel_size,
        padding=padding,
        dilation_rate=dilation_rate,
        bias_initializer=tf.constant_initializer(bias_start),
        name=conv_name)

  with tf.variable_scope(
      name, default_name="conv_gru", values=[x], reuse=reuse):
    reset = saturating_sigmoid(gate_conv(x, "reset", 1.0))
    update = saturating_sigmoid(gate_conv(x, "gate", 1.0))
    candidate = tf.tanh(gate_conv(reset * x, "candidate", 0.0))
    return update * x + (1 - update) * candidate
def gru_feedfwd(a_t, h_prev, filters, name=None):
  """position-wise Feed-fwd GRU gates following the MPNN.

  Every projection is a width-1 `tpu_conv1d`, i.e. a per-position linear map.

  Args:
    a_t: Tensor of shape [batch, length, depth] of current input
    h_prev: Tensor of shape [batch, length, depth] of prev input
    filters: an integer specifying number of dimensions of the filters
    name: A string

  Returns:
    h_t: [batch, length, filters] hidden state
  """

  def project(tensor, proj_name):
    return tpu_conv1d(tensor, filters, 1, padding="SAME", name=proj_name)

  with tf.variable_scope(name, default_name="GRU", values=[a_t, h_prev]):
    update = tf.sigmoid(project(a_t, "W_z") + project(h_prev, "U_z"))
    reset = tf.sigmoid(project(a_t, "W_r") + project(h_prev, "U_r"))
    candidate = tf.tanh(project(a_t, "W") + project(reset * h_prev, "U"))
    return (1. - update) * h_prev + update * candidate
def conv_lstm(x,
              kernel_size,
              filters,
              padding="SAME",
              dilation_rate=(1, 1),
              name=None,
              reuse=None):
  """Convolutional LSTM in 1 dimension.

  A single convolution produces all four gates. It is layer-normalized and
  split into four along axis 3. `x` itself plays the role of the previous
  cell state.
  """
  with tf.variable_scope(
      name, default_name="conv_lstm", values=[x], reuse=reuse):
    gates = conv(
        x,
        4 * filters,
        kernel_size,
        padding=padding,
        dilation_rate=dilation_rate)
    keep_g, write_g, out_g, cand_g = tf.split(
        layer_norm(gates, 4 * filters), 4, axis=3)
    cell = tf.sigmoid(keep_g) * x + tf.sigmoid(write_g) * tf.tanh(cand_g)
    return tf.sigmoid(out_g) * tf.tanh(cell)
def diagonal_conv_gru(x,
                      kernel_size,
                      filters,
                      dropout=0.0,
                      name=None,
                      reuse=None):
  """Diagonal Convolutional GRU as in https://arxiv.org/abs/1702.08727.

  Returns:
    A pair (new_state, cost). new_state mixes a channel-wise shifted copy of
    `x` with a tanh candidate via a `hard_sigmoid` gate. cost is the mean of
    the reset and update gates' `hard_sigmoid` costs.
  """

  def gate_conv(inp, conv_name, bias_start):
    return conv(
        inp,
        filters,
        kernel_size,
        padding="SAME",
        bias_initializer=tf.constant_initializer(bias_start),
        name=conv_name)

  with tf.variable_scope(
      name, default_name="diagonal_conv_gru", values=[x], reuse=reuse):
    reset, reset_cost = hard_sigmoid(gate_conv(x, "reset", 0.5))
    update, update_cost = hard_sigmoid(gate_conv(x, "gate", 0.7))
    candidate = tf.tanh(gate_conv(reset * x, "candidate", 0.0))
    if dropout > 0.0:
      candidate = tf.nn.dropout(candidate, 1.0 - dropout)

    # Depthwise [1, 3, filters, 1] kernel: the leading
    # (filters - 2 * third) channels use tap [0, 1, 0] (identity), the next
    # third use [1, 0, 0] and the last third use [0, 0, 1].
    third = filters // 3
    taps = ([[0, 1, 0]] * (filters - 2 * third) +
            [[1, 0, 0]] * third +
            [[0, 0, 1]] * third)
    kernel = tf.constant(np.transpose(taps), dtype=tf.float32)
    kernel = tf.expand_dims(tf.expand_dims(kernel, 0), 3)
    x_shifted = tf.nn.depthwise_conv2d(
        x, kernel, [1, 1, 1, 1], padding="SAME")

    mean_cost = 0.5 * (reset_cost + update_cost)
    return update * x_shifted + (1 - update) * candidate, mean_cost
def pad_to_same_length(x, y, final_length_divisible_by=1, axis=1):
  """Pad tensors x and y at the end of `axis` so they have the same length.

  Args:
    x: a Tensor.
    y: a Tensor with the same rank as x.
    final_length_divisible_by: the common length is rounded up to a multiple
      of this integer.
    axis: the axis to pad; only 1 and 2 are supported.

  Returns:
    A pair (padded_x, padded_y). If both lengths are statically known and
    equal and final_length_divisible_by == 1, x and y are returned as-is.
    Otherwise the static shapes are kept except along `axis`, which becomes
    unknown.

  Raises:
    ValueError: if axis is not 1 or 2.
  """
  if axis not in [1, 2]:
    raise ValueError("Only axis=1 and axis=2 supported for now.")
  with tf.name_scope("pad_to_same_length", values=[x, y]):
    x_length = shape_list(x)[axis]
    y_length = shape_list(y)[axis]
    both_static = isinstance(x_length, int) and isinstance(y_length, int)
    if both_static and x_length == y_length and final_length_divisible_by == 1:
      return x, y

    target_length = tf.maximum(x_length, y_length)
    if final_length_divisible_by > 1:
      # Round up to the nearest multiple of final_length_divisible_by.
      target_length += final_length_divisible_by - 1
      target_length //= final_length_divisible_by
      target_length *= final_length_divisible_by

    def _end_paddings(tensor, length_diff):
      """Paddings appending `length_diff` entries at the end of `axis`."""
      explicit = [[0, 0]] * axis + [[0, length_diff]]
      trailing = tf.zeros([tf.rank(tensor) - (axis + 1), 2], dtype=tf.int32)
      return tf.concat([explicit, trailing], axis=0)

    x_paddings = _end_paddings(x, target_length - x_length)
    y_paddings = _end_paddings(y, target_length - y_length)
    padded = []
    for source, paddings in ((x, x_paddings), (y, y_paddings)):
      result = tf.pad(source, paddings)
      static_shape = source.shape.as_list()
      static_shape[axis] = None
      result.set_shape(static_shape)
      padded.append(result)
    return tuple(padded)
def pad_with_zeros(logits, labels):
  """Pad labels on the length dimension to match logits length.

  Both tensors are brought to a common length on axis 1 (see
  `pad_to_same_length`). When labels is rank 3 (2-d labels), the same is
  done on axis 2.
  """
  with tf.name_scope("pad_with_zeros", values=[logits, labels]):
    logits, labels = pad_to_same_length(logits, labels)
    has_2d_labels = len(labels.shape) == 3
    if has_2d_labels:
      logits, labels = pad_to_same_length(logits, labels, axis=2)
    return logits, labels
def weights_nonzero(labels):
  """Assign weight 1.0 to all labels except for padding (id=0).

  Returns a float Tensor with the shape of `labels`.
  """
  is_real_token = tf.not_equal(labels, 0)
  return to_float(is_real_token)
def weights_prepend_inputs_to_targets(labels):
  """Assign weight 1.0 to only the "targets" portion of the labels.

  Weight 1.0 is assigned to all nonzero labels past the first zero.
  See prepend_mode in common_hparams.py

  Args:
    labels: A Tensor of int32s.

  Returns:
    A Tensor of floats.
  """
  # Inclusive running count of zeros along the length axis: positive from the
  # first zero onwards.
  zeros_seen = tf.cumsum(to_float(tf.equal(labels, 0)), axis=1)
  label_values = to_float(labels)
  in_targets = tf.not_equal(zeros_seen * label_values, 0)
  return to_float(in_targets)
def check_nonnegative(value):
  """Check that the value is nonnegative.

  For a `tf.Tensor`, returns an identity of it with a control dependency on
  a runtime `assert_greater_equal(value, 0)`. For any other value, raises
  immediately if it is negative and otherwise returns it unchanged.

  Raises:
    ValueError: if a non-Tensor value is negative.
  """
  if not isinstance(value, tf.Tensor):
    if value < 0:
      raise ValueError("Value must be non-negative.")
    return value
  assertion = tf.assert_greater_equal(value, 0)
  with tf.control_dependencies([assertion]):
    return tf.identity(value)
def weights_multi_problem(labels, taskid=-1):
  """Assign weight 1.0 to only the "targets" portion of the labels.

  Weight 1.0 is assigned to all nonzero labels past the taskid. The taskid
  token itself gets 0.0.

  Args:
    labels: A Tensor of int32s.
    taskid: an int32 representing the task id for a problem. The default of
      -1 is rejected by `check_nonnegative`, so callers must pass a real id.

  Returns:
    A Tensor of floats.

  Raises:
    ValueError: The Task ID must be valid.
  """
  taskid = check_nonnegative(taskid)
  after_taskid = tf.cumsum(to_float(tf.equal(labels, taskid)), axis=1)
  # Exclude the task-id position itself.
  after_taskid *= to_float(tf.not_equal(labels, taskid))
  nonpadding = to_float(labels)
  return to_float(tf.not_equal(after_taskid * nonpadding, 0))
def weights_multi_problem_all(labels, taskid=-1):
  """Assign weight 1.0 to only examples from the given task.

  Every nonzero label of an example gets weight 1.0 if the example has at
  least one nonzero label after an occurrence of `taskid`; otherwise the
  whole example gets 0.0.

  Raises:
    ValueError: if taskid is a negative non-Tensor value (including the
      default of -1).
  """
  taskid = check_nonnegative(taskid)
  token_weights = to_float(tf.not_equal(labels, 0))
  after_taskid = tf.cumsum(to_float(tf.equal(labels, taskid)), axis=1)
  # Exclude the task-id position itself.
  after_taskid *= to_float(tf.not_equal(labels, taskid))
  target_tokens = to_float(tf.not_equal(after_taskid * to_float(labels), 0))
  target_count = tf.reduce_sum(target_tokens, axis=1)
  is_task_example = to_float(
      tf.greater(target_count, tf.zeros_like(target_count)))
  return token_weights * tf.expand_dims(is_task_example, axis=-1)
def weights_multi_problem_input(labels, taskid=-1):
  """Assign weight 1.0 to only the inputs for the given task.

  Computed as the all-token weights of matching examples minus their target
  weights.
  """
  taskid = check_nonnegative(taskid)
  return (weights_multi_problem_all(labels, taskid) -
          weights_multi_problem(labels, taskid))
def weights_all(labels):
  """Assign weight 1.0 to all labels, padding included.

  Returns a float32 Tensor shaped like `labels`.
  """
  uniform_weights = tf.ones_like(labels, dtype=tf.float32)
  return uniform_weights
def weights_concatenated(labels):
  """Assign weight 1.0 to the "target" part of the concatenated labels.

  The labels look like:
    source English I love you . ID1 target French Je t'aime . ID1 source
      English the cat ID1 target French le chat ID1 source English ...

  We want to assign weight 1.0 to all words in the target text (including the
  ID1 end symbol), but not to the source text or the boilerplate. In the
  above example, the target words that get positive weight are:
    Je t'aime . ID1 le chat ID1

  Args:
    labels: a Tensor (rank 4, as required by the padding spec below)

  Returns:
    a Tensor
  """
  eos_mask = tf.cast(tf.equal(labels, 1), tf.int32)
  # Exclusive cumsum: each EOS (id 1) still belongs to the sentence it ends.
  sentence_idx = tf.cumsum(eos_mask, axis=1, exclusive=True)
  # Sentences alternate source / target, so odd indices are targets.
  in_target = tf.equal(tf.mod(sentence_idx, 2), 1)
  # The first two tokens of each sentence are boilerplate. Compare each
  # index with the one two positions earlier. The +1 ensures the zero padding
  # introduced by the shift never matches sentence 0.
  idx_plus_one = sentence_idx + 1
  two_back = tf.pad(idx_plus_one,
                    [[0, 0], [2, 0], [0, 0], [0, 0]])[:, :-2, :, :]
  not_boilerplate = tf.equal(idx_plus_one, two_back)
  return to_float(tf.logical_and(not_boilerplate, in_target))
def padded_cross_entropy(logits,
                         labels,
                         label_smoothing,
                         weights_fn=weights_nonzero,
                         reduce_sum=True,
                         cutoff=0.0,
                         gaussian=False):
  """Compute cross-entropy assuming 0s are padding.

  Computes a loss numerator (the sum of losses), and loss denominator
  (the number of non-padding tokens).

  Args:
    logits: a `Tensor` with shape `[batch, timesteps, vocab_size]`.
      optionally a FactoredTensor.
    labels: an integer `Tensor` with shape `[batch, timesteps]`.
    label_smoothing: a floating point `Scalar`.
    weights_fn: A function from labels to weights.
    reduce_sum: a Boolean, whether to sum at the end or not.
    cutoff: a float; each per-token loss is reduced by this amount and clipped
      at zero, so losses below it contribute nothing. Not applied when logits
      is a FactoredTensor.
    gaussian: If true, use a Gaussian distribution for label smoothing.

  Returns:
    loss_numerator: a `Scalar`. Sum of weighted losses (the per-token weighted
      losses when reduce_sum is False).
    loss_denominator: a `Scalar`. Sum of weights, i.e. the number of
      non-padding target tokens with the default weights_fn (per-token weights
      when reduce_sum is False).

  Raises:
    ValueError: in case of unsupported argument types, e.g. Gaussian smoothing
      with a FactoredTensor.
  """
  if isinstance(logits, FactoredTensor):
    if gaussian:
      raise ValueError("Factored padded cross entropy with Gaussian smoothing "
                       "is not implemented yet.")
    return padded_cross_entropy_factored(
        logits,
        labels,
        label_smoothing,
        weights_fn=weights_fn,
        reduce_sum=reduce_sum)

  logits_shape = shape_list(logits)
  vocab_size = logits_shape[-1]
  confidence = 1.0 - label_smoothing
  with tf.name_scope("padded_cross_entropy", values=[logits, labels]):
    if len(logits_shape) != 2:
      logits, labels = pad_with_zeros(logits, labels)
    else:
      # Deal with the case where we did not insert extra dimensions due to
      # TPU issues. No pad-to-same-length happens in this case.
      # TODO(noam): remove this logic once TPU can handle extra dimensions.
      labels = tf.reshape(labels, [-1])
    logits = tf.reshape(
        logits,
        shape_list(labels) + [vocab_size],
        name="padded_cross_entropy_size_check")
    logits = tf.cast(logits, tf.float32)
    per_token_loss = smoothing_cross_entropy(
        logits, labels, vocab_size, confidence, gaussian=gaussian)
    weights = weights_fn(labels)
    if cutoff > 0.0:
      per_token_loss = tf.nn.relu(per_token_loss - cutoff)
    weighted_loss = per_token_loss * weights
    if reduce_sum:
      return tf.reduce_sum(weighted_loss), tf.reduce_sum(weights)
    return weighted_loss, weights
def _weights_one_third(labels):
  """Returns Tensor of shape [batch, height, width]. Each element is 1/3.

  The shape is that of `labels` without its last axis.
  """
  spatial_shape = tf.shape(labels)[:-1]
  return tf.ones(spatial_shape) / 3.
def dml_loss(pred, labels, weights_fn=_weights_one_third, reduce_sum=True):
  """Discretized mixture of logistics loss.

  Args:
    pred: A [batch, height, width, num_mixtures*10] tensor of floats
      comprising one unconstrained mixture probability, three means
      (one per channel), three standard deviations (one per channel),
      and three coefficients which linearly parameterize dependence across
      channels.
    labels: A [batch, height, width, channels] tensor of 8-bit pixel
      intensities. The computation assumes channels is 3.
    weights_fn: A function of labels, returning a Tensor of shape
      [batch, height, width] which weights each loss term. Default is to scale
      each loss term by 1/3 so that they capture the average across channels.
    reduce_sum: A boolean, to return scalar loss instead of per position.

  Returns:
    Tuple of loss tensors for numerator and denominator, each a scalar if
    reduce_sum else of shape [batch, height, width]. The sum of their divisions
    is the number of nats for each pixel in labels. The denominator is 1.0
    wherever the weight is nonzero.
  """
  real_labels = convert_rgb_to_symmetric_real(labels)
  per_pixel_nats = discretized_mix_logistic_loss(pred=pred, labels=real_labels)
  weights = weights_fn(labels)
  numerator = weights * per_pixel_nats
  denominator = weights_nonzero(weights)
  if not reduce_sum:
    return numerator, denominator
  return tf.reduce_sum(numerator), tf.reduce_sum(denominator)
def split_to_discretized_mix_logistic_params(inputs):
  """Splits input tensor into parameters of discretized mixture logistic.

  Args:
    inputs: A [batch, height, width, num_mixtures*10] tensor of floats
      comprising one unconstrained mixture probability, three means
      (one per channel), three standard deviations (one per channel),
      and three coefficients which linearly parameterize dependence across
      channels.

  Returns:
    Tuple of unconstrained mixture probabilities, locations, scales, and
    coefficient parameters of the distribution. The mixture probability has
    shape [batch, height, width, num_mixtures]. Other parameters have shape
    [batch, height, width, num_mixtures, 3]. Log-scales are floored at -7 and
    coefficients are squashed by tanh into (-1, 1).
  """
  batch, height, width, output_dim = shape_list(inputs)  # pylint: disable=unbalanced-tuple-unpacking
  num_mixtures = output_dim // 10
  per_channel = num_mixtures * 3
  logits, locs, log_scales, coeffs = tf.split(
      inputs,
      num_or_size_splits=[num_mixtures, per_channel, per_channel, per_channel],
      axis=-1)
  split_shape = [batch, height, width, num_mixtures, 3]
  locs = tf.reshape(locs, split_shape)
  log_scales = tf.maximum(tf.reshape(log_scales, split_shape), -7.)
  coeffs = tf.tanh(tf.reshape(coeffs, split_shape))
  return logits, locs, log_scales, coeffs
def discretized_mix_logistic_loss(pred, labels):
  """Computes negative log probability for the discretized mixture of logistics.

  The distribution of a whole pixel is a mixture of 3-dimensional discretized
  logistic distributions. The 3-D discretized logistic factorizes as 3 1-D
  discretized logistic distributions, one for each channel. It defines

  ```none
  P(X = x)
  = sum_{k=1}^K probs[k] * P(X = x | locs[k], scales[k])
  = sum_{k=1}^K probs[k] * [
      prod_{c=1}^3 DiscretizedLogistic(X[c] = x[c] | means[k][c], scales[k]) ]
  ```

  The means tensor is a linear combination of location parameters and previous
  channels. The discretized logistic distribution assigns probability mass to an
  event P(X=x) via logistic CDFs: P(X <= x + 0.5) - P(X < x - 0.5) for
  0 < x < 255; P(X <= 0.5) for x = 0; and 1 - P(X < 254.5) for x = 255.
  Instead of 8-bit inputs, this implementation assumes the events are rescaled
  to [-1, 1]. In that space the half-bin width of 0.5 becomes 1/255.

  Args:
    pred: A [batch, height, width, num_mixtures*10] tensor of floats
      comprising one unconstrained mixture probability, three means
      (one per channel), three standard deviations (one per channel),
      and three coefficients which linearly parameterize dependence across
      channels.
    labels: A [batch, height, width, channels] tensor of true pixel intensities
      rescaled to [-1, 1]. The computation assumes channels is 3.

  Returns:
    A [batch, height, width] tensor of the negative log conditional probability
    of each pixel given all previous pixels.
  """

  logits, locs, log_scales, coeffs = split_to_discretized_mix_logistic_params(
      pred)

  # Tile labels to broadcast compute across the mixture dimension.
  batch, height, width, num_mixtures = shape_list(logits)  # pylint: disable=unbalanced-tuple-unpacking
  labels = tf.tile(
      tf.reshape(labels, [batch, height, width, 1, 3]),
      [1, 1, 1, num_mixtures, 1])

  # p(x) = sigmoid((x - means_i + 1/255.)/scale_i) -
  #        sigmoid((x - means_i - 1/255.)/scale_i)
  # for each channel i. The means are linearly parameterized.
  # Channel c's mean is shifted by coeffs times the true values of earlier
  # channels (autoregressive across channels within a pixel).
  means_0 = locs[..., 0]
  means_1 = locs[..., 1] + coeffs[..., 0] * labels[..., 0]
  means_2 = (
      locs[..., 2] + coeffs[..., 1] * labels[..., 0] +
      coeffs[..., 2] * labels[..., 1])
  means = tf.stack([means_0, means_1, means_2], axis=-1)
  centered_labels = labels - means
  inv_stdv = tf.exp(-log_scales)
  # Standardized upper and lower bin edges (half-bin = 1/255 on [-1, 1]).
  plus_in = inv_stdv * (centered_labels + 1. / 255.)
  min_in = inv_stdv * (centered_labels - 1. / 255.)
  cdf_plus = tf.nn.sigmoid(plus_in)
  cdf_min = tf.nn.sigmoid(min_in)

  # Compute log probability for edge case of 0 (before scaling), 255 (before
  # scaling), and all other cases respectively.
  # log(sigmoid(plus_in)), written stably as plus_in - softplus(plus_in).
  log_prob_0 = plus_in - tf.nn.softplus(plus_in)
  # log(1 - sigmoid(min_in)), written stably as -softplus(min_in).
  log_prob_255 = -tf.nn.softplus(min_in)
  # Floor at 1e-12 so tf.log never receives 0.
  prob_event = tf.maximum(cdf_plus - cdf_min, 1e-12)
  log_prob_event = tf.log(prob_event)

  # Robustly select log-prob based on numerical edge-cases: (a) [-1, -1+eps);
  # (b) (1-eps, 1]; (c) NaNs during `tf.gradients` of `tf.select`, which may
  # cause `tf.log(0.)`; (d) p(x) < 1e-5.
  # For (d): log of (logistic density at the bin centre * bin width 2/255);
  # np.log(127.5) == -log(2/255).
  mid_in = inv_stdv * centered_labels
  log_prob_event_approx = (
      mid_in - log_scales - 2. * tf.nn.softplus(mid_in) - np.log(127.5))
  log_probs = tf.where(
      labels < -0.999, log_prob_0,
      tf.where(
          labels > 0.999, log_prob_255,
          tf.where(prob_event > 1e-5, log_prob_event, log_prob_event_approx)))

  # Sum over channels and compute log-probability of each mixture.
  log_probs = tf.reduce_sum(log_probs, -1) + tf.nn.log_softmax(logits, axis=-1)
  output = -tf.reduce_logsumexp(log_probs, axis=-1)
  return output
def sample_from_discretized_mix_logistic(pred, seed=None):
  """Sampling from a discretized mixture of logistics.

  Args:
    pred: A [batch, height, width, num_mixtures*10] tensor of floats
      comprising one unconstrained mixture probability, three means
      (one per channel), three standard deviations (one per channel),
      and three coefficients which linearly parameterize dependence across
      channels.
    seed: Random seed.

  Returns:
    A tensor of shape [batch, height, width, 3] with real intensities scaled
    between -1 and 1.
  """
  logits, locs, log_scales, coeffs = split_to_discretized_mix_logistic_params(
      pred)
  num_mixtures = shape_list(logits)[-1]
  # Keeps uniform samples away from 0 and 1 so the logs below stay finite.
  eps = 1e-5

  # Gumbel-max trick: argmax(logits + Gumbel noise) picks a mixture component.
  # NOTE(review): this draw and the logistic draw below use the same op-level
  # `seed`, which may correlate the two noise streams -- confirm acceptable.
  gumbel_u = tf.random_uniform(
      tf.shape(logits), minval=eps, maxval=1. - eps, seed=seed)
  gumbel = -tf.log(-tf.log(gumbel_u))
  chosen = tf.one_hot(
      tf.argmax(logits + gumbel, -1), depth=num_mixtures, dtype=tf.float32)
  chosen = tf.expand_dims(chosen, -1)

  def _select(params):
    """Keeps only the chosen component's parameters (sum over axis 3)."""
    return tf.reduce_sum(params * chosen, 3)

  locs = _select(locs)
  log_scales = _select(log_scales)
  coeffs = _select(coeffs)

  # Inverse-CDF sample of a logistic, log(u) - log(1 - u), then clip. Note we
  # don't round to the nearest 8-bit value when sampling.
  logistic_u = tf.random_uniform(
      tf.shape(locs), minval=eps, maxval=1. - eps, seed=seed)
  noise = tf.log(logistic_u) - tf.log1p(-logistic_u)
  sample = locs + tf.exp(log_scales) * noise
  ch0 = sample[..., 0]
  ch1 = sample[..., 1] + coeffs[..., 0] * ch0
  ch2 = sample[..., 2] + coeffs[..., 1] * ch0 + coeffs[..., 2] * ch1
  return tf.clip_by_value(tf.stack([ch0, ch1, ch2], axis=-1), -1., 1.)
def smoothing_cross_entropy(logits,
                            labels,
                            vocab_size,
                            confidence,
                            gaussian=False):
  """Cross entropy with label smoothing to limit over-confidence.

  Args:
    logits: Tensor of shape [batch_size, ?, ?, ?, vocab_size].
    labels: Tensor of shape [batch_size, ?, ?, ?].
    vocab_size: Tensor representing the size of the vocabulary.
    confidence: Used to determine on and off values for label smoothing.
      If `gaussian` is true, `confidence` is the variance to the Gaussian
      distribution.
    gaussian: Uses a Gaussian distribution for label smoothing

  Returns:
    Tensor of shape [batch_size, ?, ?, ?]. The result is offset by the
    soft-target entropy of the one-hot smoothing scheme.
  """
  with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
    num_other = to_float(vocab_size - 1)
    # Probability mass given to each non-true label, uniformly.
    low_confidence = (1.0 - confidence) / num_other
    # Normalizing constant is the best cross-entropy value with soft targets.
    # Subtracting it is only for readability; gradients are unchanged.
    normalizing = -(
        confidence * tf.log(confidence) +
        num_other * low_confidence * tf.log(low_confidence + 1e-20))

    if gaussian and confidence > 0.0:
      float_labels = tf.cast(labels, tf.float32)
      normal_dist = tfp.distributions.Normal(loc=float_labels, scale=confidence)
      # Evaluate the density at every vocabulary id:
      # [vocab_size, batch_size, ?, ?, ?].
      vocab_ids = tf.cast(tf.range(vocab_size), tf.float32)
      soft_targets = normal_dist.prob(vocab_ids[:, None, None, None, None])
      # Move the vocab axis last to match logits.
      soft_targets = tf.transpose(soft_targets, perm=[1, 2, 3, 4, 0])
    else:
      soft_targets = tf.one_hot(
          tf.cast(labels, tf.int32),
          depth=vocab_size,
          on_value=confidence,
          off_value=low_confidence)

    xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_targets)
    return xentropy - normalizing
def global_pool_1d(inputs, pooling_type="MAX", mask=None):
  """Pool elements across the sequence dimension (axis 1).

  Useful to convert a list of vectors into a single vector so as
  to get a representation of a set.

  Args:
    inputs: A tensor of shape [batch_size, sequence_length, input_dims]
      containing the sequences of input vectors.
    pooling_type: the pooling type to use, MAX or AVR
    mask: A tensor of shape [batch_size, sequence_length] containing a
      mask for the inputs with 1's for existing elements, and 0's elsewhere.

  Returns:
    A tensor of shape [batch_size, input_dims] containing the sequences of
    transformed vectors. With a mask, padded positions never contribute to
    the result, and a row with no existing elements pools to 0.

  Raises:
    ValueError: if pooling_type is neither "MAX" nor "AVR".
  """
  with tf.name_scope("global_pool", values=[inputs]):
    if mask is not None:
      mask = tf.expand_dims(mask, axis=2)
      inputs = tf.multiply(inputs, mask)

    if pooling_type == "MAX":
      if mask is None:
        # A tf.pool can be used here, but reduce is cleaner
        output = tf.reduce_max(inputs, axis=1)
      else:
        # Padded positions were zeroed above, which would let them win the max
        # whenever every existing value is negative. Push them to the dtype's
        # lowest value instead. 0 * lowest leaves existing positions untouched.
        lowest = (1 - mask) * inputs.dtype.min
        output = tf.reduce_max(inputs + lowest, axis=1)
        # Rows with no existing elements would otherwise pool to dtype.min.
        has_elements = tf.cast(tf.reduce_max(mask, axis=1) > 0, output.dtype)
        output *= has_elements
    elif pooling_type == "AVR":
      if mask is not None:
        # Some elems are dummy elems so we can't just reduce the average.
        output = tf.reduce_sum(inputs, axis=1)
        num_elems = tf.reduce_sum(mask, axis=1, keepdims=True)
        output = tf.div(output, tf.maximum(num_elems, 1))
      else:
        output = tf.reduce_mean(inputs, axis=1)
    else:
      raise ValueError(
          "Unknown pooling_type %r; expected 'MAX' or 'AVR'." % pooling_type)

  return output
def running_global_pool_1d(inputs, pooling_type="MAX"):
  """Same global pool, but only for the elements up to the current element.

  Useful for outputs where the state of future elements is not known.
  Takes no mask as all elements up to the current element are assumed to exist.
  Currently only supports maximum. Equivalent to using a lower triangle bias.

  Args:
    inputs: A tensor of shape [batch_size, sequence_length, input_dims]
      containing the sequences of input vectors.
    pooling_type: Pooling type to use. Ignored; a running max is always used.

  Returns:
    A tensor of shape [batch_size, sequence_length, input_dims] containing the
    running 'totals'.
  """
  del pooling_type
  with tf.name_scope("running_global_pool", values=[inputs]):
    # tf.scan iterates over axis 0, so make the sequence axis leading.
    time_major = tf.transpose(inputs, [1, 0, 2])
    running_max = tf.scan(tf.maximum, time_major, swap_memory=True)
    return tf.transpose(running_max, [1, 0, 2])
def gated_linear_unit_layer(x, name=None):
  """Gated linear unit layer.

  Paper: Language Modeling with Gated Convolutional Networks.
  Link: https://arxiv.org/abs/1612.08083
  x = Wx * sigmoid(W'x).

  A single Dense layer of width 2 * depth produces both halves, which are
  split along the last axis.

  Args:
    x: A tensor
    name: A string

  Returns:
    A tensor of the same shape as x.
  """
  with tf.variable_scope(name, default_name="glu_layer", values=[x]):
    depth = shape_list(x)[-1]
    projected = layers().Dense(depth * 2, activation=None)(x)
    values, gates = tf.split(projected, 2, axis=-1)
    return values * tf.nn.sigmoid(gates)
def sru(x,
        num_layers=2,
        activation=None,
        initial_state=None,
        name=None,
        reuse=None):
  """SRU cell as in https://arxiv.org/abs/1709.02755.

  This implementation uses tf.scan and can incur overhead, see the full SRU
  function doc for details and an implementation that is sometimes faster.

  Args:
    x: A tensor of shape [batch, ..., channels] ; ... is treated as time.
    num_layers: How many SRU layers; default is 2 as results for 1 disappoint.
    activation: Optional activation function, try tf.nn.tanh or tf.nn.relu.
    initial_state: Optional initial c-state of shape [batch, channels], set to
      zeros if None. The same initial state is used for every layer.
    name: Optional name, "sru" by default.
    reuse: Optional reuse.

  Returns:
    A tensor of the same shape as x.

  Raises:
    ValueError: if num_layers is not positive.
  """
  if num_layers < 1:
    raise ValueError("Number of layers must be positive: %d" % num_layers)
  with tf.variable_scope(name, default_name="sru", values=[x], reuse=reuse):
    # We assume x is [batch, ..., channels] and treat all ... as time.
    x_shape = shape_list(x)
    x = tf.reshape(x, [x_shape[0], -1, x_shape[-1]])
    x = tf.transpose(x, [1, 0, 2])  # Scan assumes time on axis 0.
    # Test for None explicitly. `initial_state or ...` calls bool() on a
    # Tensor, which raises in graph mode whenever a state is passed.
    if initial_state is None:
      initial_state = tf.zeros([x_shape[0], x_shape[-1]])

    # SRU state manipulation function.
    def next_state(cur_state, args_tup):
      cur_x_times_one_minus_f, cur_f = args_tup
      return cur_f * cur_state + cur_x_times_one_minus_f

    # Calculate SRU on each layer.
    for i in range(num_layers):
      # The parallel part of the SRU.
      x_orig = x
      x, f, r = tf.split(
          layers().Dense(3 * x_shape[-1], name="kernel_%d" % i)(x), 3, axis=-1)
      f, r = tf.sigmoid(f), tf.sigmoid(r)
      x_times_one_minus_f = x * (1.0 - f)  # Compute in parallel for speed.
      # Calculate states.
      c_states = tf.scan(
          next_state, (x_times_one_minus_f, f),
          initializer=initial_state,
          parallel_iterations=2,
          name="scan_%d" % i)
      # Final output.
      if activation is not None:
        c_states = activation(c_states)
      h = c_states * r + (1.0 - r) * x_orig
      x = h  # Next layer.
    # Transpose back to batch-major.
    x = tf.transpose(x, [1, 0, 2])
    return tf.reshape(x, x_shape)
def linear_set_layer(layer_size,
                     inputs,
                     context=None,
                     activation_fn=tf.nn.relu,
                     dropout=0.0,
                     name=None):
  """Basic layer type for doing funky things with sets.

  Applies a linear transformation to each element in the input set.
  If a context is supplied, it is concatenated with the inputs.
    e.g. One can use global_pool_1d to get a representation of the set which
    can then be used as the context for the next layer.

  TODO: Add bias add (or control the biases used).

  Args:
    layer_size: Dimension to transform the input vectors to.
    inputs: A tensor of shape [batch_size, sequence_length, input_dims]
      containing the sequences of input vectors.
    context: A tensor of shape [batch_size, context_dims] containing a global
      statistic about the set.
    activation_fn: The activation function to use.
    dropout: Dropout probability.
    name: name.

  Returns:
    Tensor of shape [batch_size, sequence_length, output_dims] containing the
    sequences of transformed vectors.
  """
  with tf.variable_scope(
      name, default_name="linear_set_layer", values=[inputs]):
    # A width-1 conv1d applies the same linear map to every element.
    outputs = conv1d(inputs, layer_size, 1, activation=None, name="set_conv")

    if context is not None:
      # tf.concat does not broadcast. Projecting the context separately and
      # adding it has the same effect as concatenating it to every element.
      if len(context.get_shape().as_list()) == 2:
        context = tf.expand_dims(context, axis=1)
      outputs += conv1d(
          context, layer_size, 1, activation=None, name="cont_conv")

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    if dropout != 0.0:
      outputs = tf.nn.dropout(outputs, 1.0 - dropout)
    return outputs
def ravanbakhsh_set_layer(layer_size,
                          inputs,
                          mask=None,
                          sequential=False,
                          activation_fn=tf.nn.tanh,
                          dropout=0.0,
                          name=None):
  """Layer from the Deep Sets paper: https://arxiv.org/abs/1611.04500 .

  Subtracts a pooled summary of the set from each element and then applies
  `linear_set_layer`. This is a more parameter-efficient alternative to a
  linear set layer that takes the summary as a separate context.

  Args:
    layer_size: Dimension to transform the input vectors to.
    inputs: A tensor of shape [batch_size, sequence_length, vector]
      containing the set elements.
    mask: A tensor of shape [batch_size, sequence_length] with 1 for existing
      elements and 0 elsewhere. Only used when `sequential` is False.
    sequential: If True, use a running (prefix) pool so each element depends
      only on the elements before it; set this for output sequences.
    activation_fn: The activation function to use.
    dropout: Accepted for interface compatibility; ignored.
    name: Variable scope name.

  Returns:
    Tensor of shape [batch_size, sequence_length, vector].
  """
  del dropout  # Unused.
  with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]):
    if sequential:
      pooled = running_global_pool_1d(inputs)
    else:
      pooled = tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1)
    return linear_set_layer(
        layer_size,
        inputs - pooled,
        activation_fn=activation_fn,
        name=name)
def fn_device_dependency_dict():
  """Return the per-graph state used by `fn_device_dependency`.

  The state is a `collections.defaultdict(list)` stored as the
  `dependency_dict` attribute of the current default graph. It is created on
  first access.
  """
  graph = tf.get_default_graph()
  if hasattr(graph, "dependency_dict"):
    return graph.dependency_dict
  graph.dependency_dict = collections.defaultdict(list)
  return graph.dependency_dict
@contextlib.contextmanager
def fn_device_dependency(name, device=""):
  """Add control deps for name and device.

  Ops created inside the `with` block get control dependencies on whatever
  was last recorded under the key `name + "_" + device`. The caller must
  append its result tensor(s) to the yielded list. On exit, those become the
  dependencies recorded for the next user of the same key.

  Args:
    name: a string naming the dependency chain.
    device: optional device string. When non-empty, the ops created inside
      the `with` block are placed on this device.

  Yields:
    An empty list that the caller appends its output(s) to. A single
    list/tuple element is unpacked into the recorded dependencies.
  """
  key = name + "_" + device
  outs = []
  with contextlib.ExitStack() as stack:
    # The device scope must be entered inside this generator so that it is
    # still active while the caller's block runs. Returning a generator from
    # inside `with tf.device(...)` would leave the scope before the caller's
    # block ever executes.
    if device:
      stack.enter_context(tf.device(device))
    with tf.control_dependencies(fn_device_dependency_dict()[key]):
      yield outs
      assert outs

      deps = outs
      if isinstance(outs[0], (list, tuple)):
        assert len(outs) == 1
        deps = outs[0]
      fn_device_dependency_dict()[key] = deps
def underlying_variable_ref(t):
  """Walk back from tensor `t` to the variable tensor that produced it.

  Repeatedly follows the first input of "Identity", "ReadVariableOp" and
  "Enter" ops. The tensor reached is returned if its op type contains
  "Variable" or "VarHandle".

  Args:
    t: a Tensor

  Returns:
    The variable ref Tensor, or None if the walk ended at a non-variable op.
  """
  passthrough_ops = ("Identity", "ReadVariableOp", "Enter")
  current = t
  while current.op.type in passthrough_ops:
    current = current.op.inputs[0]
  kind = current.op.type
  is_variable = "Variable" in kind or "VarHandle" in kind
  return current if is_variable else None
def underlying_variable(t):
  """Return the `tf.Variable` object backing tensor `t`.

  A name -> variable index is cached on the default graph as `var_index`.
  Before each lookup it is extended with any global variables added since
  its size was last recorded.

  Args:
    t: a Tensor

  Returns:
    tf.Variable.
  """
  ref = underlying_variable_ref(t)
  assert ref is not None
  graph = tf.get_default_graph()
  if not hasattr(graph, "var_index"):
    graph.var_index = {}
  index = graph.var_index
  # Assumes the global variable collection only grows by appending, so the
  # entries past the current index size are the new ones.
  new_vars = tf.global_variables()[len(index):]
  for variable in new_vars:
    index[variable.name] = variable
  return index[ref.name]
def approximate_split(x, num_splits, axis=0):
  """Split `x` along `axis` into `num_splits` nearly equal pieces.

  Piece i has size floor((size + i) / num_splits). These sizes sum to `size`
  and differ from each other by at most one.

  Args:
    x: a Tensor
    num_splits: an integer
    axis: an integer.

  Returns:
    a list of num_splits Tensors.
  """
  total = shape_list(x)[axis]
  sizes = []
  for offset in range(num_splits):
    sizes.append(tf.div(total + offset, num_splits))
  return tf.split(x, sizes, axis=axis)
class FactoredTensor(object):
  """Lazy representation of tf.matmul(a, b, transpose_b=True) by its factors.

  The full product may be too large to materialize at once, so only `a` and
  `b` are stored and the product can be realized piecewise by callers.
  `a` may have extra leading dimensions. They are flattened before the matrix
  product and restored afterwards.
  """

  def __init__(self, a, b):
    self._a = a
    self._b = b

  @property
  def a(self):
    """Left factor; its last axis is the shared inner dimension."""
    return self._a

  @property
  def b(self):
    """Right factor, laid out as [result_dim, inner_dim]."""
    return self._b

  def to_tensor(self):
    """Materialize the full product as a single Tensor."""
    lhs, rhs = self.a, self.b
    lhs_dims = shape_list(lhs)
    rhs_dims = shape_list(rhs)
    result_dim, inner_dim = rhs_dims[0], rhs_dims[1]
    # Fold all leading axes of `a` together so a single 2-D matmul suffices.
    lhs_2d = tf.reshape(lhs, [-1, inner_dim])
    flat_product = tf.matmul(lhs_2d, rhs, transpose_b=True)
    product = tf.reshape(flat_product, lhs_dims[:-1] + [result_dim])
    # Re-attach whatever static shape information the factors carry.
    static_shape = lhs.get_shape().as_list()[:-1] + [rhs.get_shape()[0]]
    product.set_shape(static_shape)
    return product
def _convert_factored_tensor_to_tensor(value, *args, **kwargs):
  """Tensor conversion hook for `FactoredTensor`.

  Materializes the product and forwards it, along with whatever optional
  arguments TensorFlow supplies, to `ops.convert_to_tensor`.
  """
  dense_value = value.to_tensor()
  return ops.convert_to_tensor(dense_value, *args, **kwargs)


# Let a FactoredTensor be passed anywhere a Tensor is expected.
tf.register_tensor_conversion_function(
    FactoredTensor, _convert_factored_tensor_to_tensor)
def smoothing_cross_entropy_factored_grad(op, dy):
  """Gradient function for smoothing_cross_entropy_factored.

  Splits the batch into 16 chunks and, for each chunk, recomputes the logits
  and back-propagates the incoming gradient through them. The chunks are
  chained with control dependencies so they are processed one after another.

  Returns:
    (grad wrt a, grad wrt b, None, None). Labels and confidence get no
    gradient.
  """
  a, b, labels, confidence = [op.inputs[i] for i in range(4)]
  num_splits = 16
  vocab_size = shape_list(b)[0]
  label_chunks = approximate_split(labels, num_splits)
  a_chunks = approximate_split(a, num_splits)
  dy_chunks = approximate_split(dy, num_splits)
  b_grad = None
  a_grads = []
  previous = []
  chunks = zip(a_chunks, label_chunks, dy_chunks)
  for part, (a_chunk, label_chunk, dy_chunk) in enumerate(chunks):
    with tf.control_dependencies(previous):
      chunk_logits = tf.matmul(a_chunk, b, transpose_b=True)
      chunk_loss = smoothing_cross_entropy(chunk_logits, label_chunk,
                                           vocab_size, confidence)
      a_chunk_grad, b_chunk_grad = tf.gradients(
          ys=[chunk_loss], xs=[a_chunk, b], grad_ys=[dy_chunk])
      a_grads.append(a_chunk_grad)
      # `b` is shared by all chunks, so its gradient is the running sum.
      b_grad = b_grad + b_chunk_grad if part else b_chunk_grad
      previous = [b_grad, a_chunk_grad]
  return tf.concat(a_grads, 0), b_grad, None, None
@function.Defun(
    noinline=True,
    python_grad_func=smoothing_cross_entropy_factored_grad,
    compiled=True,
    separate_compiled_gradients=True)
def smoothing_cross_entropy_factored(a, b, labels, confidence):
  """Memory-efficient computation of smoothing cross-entropy.

  Avoids realizing the entire logits matrix at once. The batch is processed
  in 16 chunks that are serialized with control dependencies.

  Args:
    a: a Tensor with shape [batch, inner_dim]
    b: a Tensor with shape [vocab_size, inner_dim]
    labels: an integer Tensor with shape [batch]
    confidence: a float

  Returns:
    A Tensor with shape [batch]
  """
  num_splits = 16
  vocab_size = shape_list(b)[0]
  label_chunks = approximate_split(labels, num_splits)
  a_chunks = approximate_split(a, num_splits)
  losses = []
  for a_chunk, label_chunk in zip(a_chunks, label_chunks):
    # Each chunk waits for the previous chunk's loss.
    with tf.control_dependencies(losses[-1:]):
      chunk_logits = tf.matmul(a_chunk, b, transpose_b=True)
      losses.append(
          smoothing_cross_entropy(chunk_logits, label_chunk, vocab_size,
                                  confidence))
  return tf.concat(losses, 0)
def padded_cross_entropy_factored(factored_logits,
                                  labels,
                                  label_smoothing,
                                  weights_fn=weights_nonzero,
                                  reduce_sum=True):
  """Memory-efficient computation of smoothing cross-entropy.

  Avoids realizing the entire logits matrix at once.

  Args:
    factored_logits: a `FactoredTensor` representing a Tensor
      with shape `[batch, timesteps, vocab_size]`.
    labels: an integer `Tensor` with shape `[batch, timesteps]`.
    label_smoothing: a floating point `Scalar`.
    weights_fn: A function from labels to weights.
    reduce_sum: a Boolean, whether to sum at the end or not.

  Returns:
    If `reduce_sum` is True:
      loss_numerator: a `Scalar`. Sum of weighted losses.
      loss_denominator: a `Scalar`. Sum of the weights (the number of
        non-padding target tokens for the default `weights_fn`).
    Otherwise the per-position weighted losses and the weights, both shaped
    like `labels`.
  """
  a, b = factored_logits.a, factored_logits.b
  confidence = 1.0 - label_smoothing
  with tf.name_scope("padded_cross_entropy_factored", values=[a, b, labels]):
    flat_labels = tf.reshape(labels, [-1])
    inner_dim = shape_list(b)[1]
    flat_a = tf.reshape(a, [-1, inner_dim])
    flat_xent = smoothing_cross_entropy_factored(
        flat_a, b, flat_labels, tf.convert_to_tensor(confidence))
    xent = tf.reshape(flat_xent, shape_list(labels))
    weights = weights_fn(labels)
    weighted = xent * weights
    if reduce_sum:
      return tf.reduce_sum(weighted), tf.reduce_sum(weights)
    return weighted, weights
def fn_with_custom_grad(grad_fn, use_global_vars=False):
  """Decorator giving the decorated subgraph a custom gradient function.

  The decorated function's ops are NOT placed inside a Defun, so they avoid
  Defun's limitations (all ops on one device, no summaries).

  Args:
    grad_fn: function with signature
      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors.
    use_global_vars: if True, `variables` are the global variables created by
      the function; if False, the trainable variables.

  Returns:
    Decorator for function such that the gradient is defined by grad_fn.
  """
  def decorator(fn):
    def call_with_custom_grad(*args):
      return _fn_with_custom_grad(
          fn, args, grad_fn, use_global_vars=use_global_vars)
    return functools.wraps(fn)(call_with_custom_grad)
  return decorator
def _fn_with_custom_grad(fn, inputs, grad_fn, use_global_vars=False):
  """Create a subgraph with a custom gradient.

  Runs `fn(*inputs)` normally, then routes its outputs through an identity
  Defun whose Python gradient calls `grad_fn`. Variables "created by fn" are
  those appended to the current variable scope's (global or trainable)
  variable list during the call.

  Args:
    fn: function that takes inputs as arguments and produces 1 or more Tensors.
    inputs: list, will be passed as fn(*inputs).
    grad_fn: function with signature
      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors. They must be lists (not tuples)
      because they are concatenated with `+` below.
    use_global_vars: if True, variables will be the global variables created.
      If False, will be the trainable variables.

  Returns:
    fn(*inputs) unchanged when `grad_fn` is None; otherwise the result of the
    identity Defun applied to the outputs.
  """
  vs = tf.get_variable_scope()
  get_vars_fn = (
      vs.global_variables if use_global_vars else vs.trainable_variables)
  # Snapshot the variable count so that the variables fn creates can be
  # identified as the tail of the list afterwards.
  len_before_vars = len(get_vars_fn())
  inputs = list(inputs)
  outputs = fn(*inputs)
  train_vars = get_vars_fn()[len_before_vars:]

  if grad_fn is None:
    return outputs

  if not isinstance(outputs, (tuple, list)):
    outputs = [outputs]
  outputs = list(outputs)

  # Nested structure whose flattening defines the Defun's argument order:
  # inputs, then created variables, then outputs.
  defun_inputs = [inputs, train_vars, outputs]

  def custom_grad_fn(op, *dys):
    """Custom grad fn applying grad_fn for identity Defun."""
    # Re-nest the Defun op's flat inputs into (inputs, vars, outputs).
    fn_inputs, fn_vars, fn_outputs = contrib.framework().nest.pack_sequence_as(
        defun_inputs, list(op.inputs))
    dys = list(dys)
    assert len(fn_outputs) == len(outputs)
    assert len(fn_outputs) == len(dys)

    grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
    # The Defun's `outputs` arguments receive no gradient.
    grad_outputs = [None] * len(fn_outputs)
    return tuple(grad_inputs + grad_vars + grad_outputs)

  # The Defun takes as input the original inputs, the trainable variables
  # created in fn, and the outputs. In the forward it passes through the
  # outputs. In the backwards, it produces gradients for the original inputs
  # and the trainable variables.
  in_types = [t.dtype for t in inputs]
  out_types = [t.dtype for t in outputs]
  var_types = [t.dtype for t in train_vars]

  # ops.uid() makes the function name unique per call.
  @function.Defun(
      *(in_types + var_types + out_types),
      func_name="identity_custom_grad%d" % ops.uid(),
      python_grad_func=custom_grad_fn,
      shape_func=lambda _: [t.get_shape() for t in outputs])
  def identity(*args):
    _, _, outs = contrib.framework().nest.pack_sequence_as(defun_inputs, args)
    return tuple([tf.identity(t) for t in outs])

  flat_inputs = contrib.framework().nest.flatten(defun_inputs)
  id_out = identity(*flat_inputs)
  return id_out
# Cache of forward Defuns for conv_hidden_relu_memory_efficient, keyed by a
# string that includes epsilon.
# NOTE(review): in the code visible here this dict is only read, never
# written, so the cache lookup cannot hit unless something elsewhere populates
# it — confirm whether caching was intended.
_function_cache = {}
def conv_hidden_relu_memory_efficient(x,
                                      filter_size,
                                      epsilon=1e-6,
                                      forget=True,
                                      test_vars=None,
                                      name=None):
  """LayerNorm, Conv, ReLU, Conv.

  All convolutions have kernel size 1.

  returns conv(relu(conv(layer_norm(x))))

  The batch (flattened with length) is processed in 4 sequential chunks.
  When `forget` is True the computation runs inside a compiled Defun whose
  custom gradient recomputes the forward pass chunk by chunk, instead of
  relying on activations saved from the forward pass.

  Args:
    x: input Tensor with shape [batch, length, io_size]
    filter_size: an integer - size of the hidden layer.
    epsilon: a float (for layer norm)
    forget: a boolean - forget forwards activations and recompute on backprop
    test_vars: optional tuple of variables for testing purposes
    name: an optional string

  Returns:
    a Tensor with shape [batch, length, io_size]
  """
  io_size = x.get_shape().as_list()[-1]

  def forward_internal(x, f1, f2, scale, bias):
    """Forward function."""
    # split batch-wise to avoid exhausting memory in case the batch is large
    # and the hidden layer is large.
    num_splits = 4
    # Every (batch, position) pair becomes its own length-1 sequence.
    x_flat = tf.reshape(x, [-1, 1, shape_list(x)[2]])
    xs = approximate_split(x_flat, num_splits)
    ys = []
    for i in range(num_splits):
      # Each chunk waits for the previous chunk's output.
      with tf.control_dependencies(ys[-1:]):
        n = layer_norm_compute(xs[i], epsilon, scale, bias)
        y = tf.nn.conv1d(n, f1, 1, "SAME")
        y = tf.nn.relu(y)
        y = tf.nn.conv1d(y, f2, 1, "SAME")
        ys.append(y)
    y = tf.concat(ys, 0)
    y = tf.reshape(y, shape_list(x))
    return y

  key = ("conv_hidden_relu_memory_efficient %s" % epsilon)
  if not forget:
    forward_fn = forward_internal
  elif key in _function_cache:
    forward_fn = _function_cache[key]
  else:

    @function.Defun(compiled=True)
    def grad_fn(x, f1, f2, scale, bias, dy):
      """Gradient for efficiency."""
      # Recompute the forward pass per chunk (after dy is available) and
      # accumulate parameter gradients across chunks.
      with tf.control_dependencies([dy]):
        num_splits = 4
        x_shape = shape_list(x)
        flat_shape = [-1, 1, x_shape[2]]
        x = tf.reshape(x, flat_shape)
        dy = tf.reshape(dy, flat_shape)
        xs = approximate_split(x, num_splits)
        dys = approximate_split(dy, num_splits)
        dxs = []
        df1 = 0
        df2 = 0
        dscale = 0
        dbias = 0
        deps = []
        for i in range(num_splits):
          with tf.control_dependencies(deps):
            n = layer_norm_compute(xs[i], epsilon, scale, bias)
            y = tf.nn.conv1d(n, f1, 1, "SAME")
            y = tf.nn.relu(y)
            y = tf.nn.conv1d(y, f2, 1, "SAME")
            dxi, pdf1, pdf2, pdscale, pdbias = tf.gradients(
                ys=[y], xs=[xs[i], f1, f2, scale, bias], grad_ys=[dys[i]])
            df1 += pdf1
            df2 += pdf2
            dscale += pdscale
            dbias += pdbias
            dxs.append(dxi)
            # Serialize the next chunk behind this chunk's results.
            deps = [dxi, df1, df2, dscale, dbias]
        with tf.control_dependencies(deps):
          dx = tf.concat(dxs, 0)
          dx = tf.reshape(dx, x_shape)
        return dx, df1, df2, dscale, dbias

    @function.Defun(
        grad_func=grad_fn, compiled=True, separate_compiled_gradients=True)
    def forward_fn(x, f1, f2, scale, bias):
      return forward_internal(x, f1, f2, scale, bias)

  with tf.variable_scope(name, default_name="ffn2", values=[x]):
    # TODO(noam): it would be nice to save memory by casting x to float16
    # here, but this causes problems with the gradients.  Figure out if there
    # is a way to leave the gradients as float32.
    if test_vars is not None:
      f1, f2, scale, bias = list(test_vars)
    else:
      f1 = tf.get_variable("f1", [1, io_size, filter_size])
      f2 = tf.get_variable("f2", [1, filter_size, io_size])
      scale, bias = layer_norm_vars(io_size)
    if forget:
      y = forward_fn(x, f1, f2, scale, bias)
    else:
      y = forward_internal(x, f1, f2, scale, bias)
    # The Defun loses static shape information; restore it from x.
    y.set_shape(x.get_shape())
    return y
def shape_list(x):
  """Return the shape of `x` as a list, preferring static sizes.

  Known dimensions are Python ints. Unknown dimensions are scalar tensors
  sliced from `tf.shape(x)`. If even the rank is unknown, the whole dynamic
  shape tensor is returned instead of a list.
  """
  x = tf.convert_to_tensor(x)
  if x.get_shape().dims is None:
    return tf.shape(x)
  static_dims = x.get_shape().as_list()
  dynamic = tf.shape(x)
  return [dynamic[i] if dim is None else dim
          for i, dim in enumerate(static_dims)]
def list_product(els):
  """Return the product of all elements of `els`.

  Works for Python numbers, numpy values and scalar tensors.

  Args:
    els: a sequence of values supporting `*`.

  Returns:
    The product of the elements, or 1 for an empty sequence.
  """
  # len() rather than truthiness so numpy arrays are accepted.
  if len(els) == 0:  # pylint: disable=g-explicit-length-test
    return 1
  prod = els[0]
  for el in els[1:]:
    # Out-of-place multiply: `*=` would mutate els[0] in place when it is a
    # mutable numeric object such as a numpy array.
    prod = prod * el
  return prod
def sample_with_temperature(logits, temperature, sampling_keep_top_k=-1):
  """Either argmax or random sampling.

  Args:
    logits: a Tensor whose last axis indexes the vocabulary.
    temperature: a float 0.0=argmax 1.0=random
    sampling_keep_top_k: If not -1, only sample from the top k logits.

  Returns:
    a Tensor with one fewer dimension than logits.

  Raises:
    ValueError: if sampling_keep_top_k is neither -1 nor positive.
  """
  if temperature == 0.0:
    # TF argmax doesn't handle >5 dimensions, so we reshape here.
    logits_shape = shape_list(logits)
    argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=1)
    return tf.reshape(argmax, logits_shape[:-1])
  tf.debugging.assert_greater(temperature, 0.0)
  if sampling_keep_top_k != -1:
    if sampling_keep_top_k <= 0:
      raise ValueError("sampling_keep_top_k must either be -1 or positive.")
    ndims = logits.get_shape().ndims
    vocab_size = shape_list(logits)[-1]
    # With reverse=True and 0-based n, this is the (k+1)-th largest logit
    # along the last axis, i.e. the largest value to exclude.
    k_largest = contrib.nn().nth_element(
        logits, n=sampling_keep_top_k, reverse=True)
    # Tile to the full logits shape (any rank) for the element-wise tf.where.
    k_largest = tf.tile(
        tf.expand_dims(k_largest, axis=-1), [1] * (ndims - 1) + [vocab_size])
    # Force every position that is not in the top k to have probability near
    # 0 by setting the logit to be very negative.
    logits = tf.where(tf.less_equal(logits, k_largest),
                      tf.ones_like(logits) * -1e6, logits)
  reshaped_logits = (
      tf.reshape(logits, [-1, shape_list(logits)[-1]]) / temperature)
  choices = tf.multinomial(reshaped_logits, 1)
  choices = tf.reshape(choices,
                       shape_list(logits)[:logits.get_shape().ndims - 1])
  return choices
def _select_top_k(logits, top_k):
  """Replaces logits, except the top k highest values, with small number (-1e6).

  If k is -1 don't replace anything.

  Args:
    logits: A `Tensor` of shape [batch_size, ..., vocab_size]
    top_k: vector of batch size (a scalar applies to all examples).

  Returns:
    A `Tensor` with same shape as logits.
  """
  vocab_size = shape_list(logits)[-1]

  top_k = tf.where(
      tf.not_equal(top_k, -1), top_k,
      tf.ones_like(top_k) * vocab_size)

  # Per-position rank along the last axis in descending order (0 = largest).
  # tf.argsort alone returns the indices in sorted order, not the rank of
  # each position; argsort of that permutation inverts it into ranks.
  order = tf.argsort(logits, axis=-1, direction="DESCENDING")
  ranks = tf.argsort(order, axis=-1, direction="ASCENDING")
  top_k = tf.reshape(top_k, [-1] + [1] * (len(logits.shape) - 1))
  keep = ranks < tf.cast(top_k, ranks.dtype)
  return tf.where(keep, logits, tf.ones_like(logits) * -1e6)
def sample_temperature_per_example(logits, temperature, sampling_keep_top_k=-1):
  """Random sampling with a separate temperature for each example.

  Args:
    logits: a Tensor of shape [batch_size, ..., vocab_size].
    temperature: a float vector with one entry per example (it is reshaped to
      broadcast over all non-batch axes).
    sampling_keep_top_k: If not -1, only sample from the top k logits.

  Returns:
    a Tensor with one fewer dimension than logits.
  """
  logits = _select_top_k(logits, sampling_keep_top_k)
  broadcast_shape = [-1] + [1] * (len(logits.shape) - 1)
  scaled = logits / tf.reshape(temperature, broadcast_shape)
  flat_scaled = tf.reshape(scaled, [-1, shape_list(scaled)[-1]])
  samples = tf.multinomial(flat_scaled, 1)
  return tf.reshape(samples,
                    shape_list(scaled)[:scaled.get_shape().ndims - 1])
def ones_matrix_band_part(rows, cols, num_lower, num_upper, out_shape=None):
  """Matrix band part of ones.

  Args:
    rows: int determining number of rows in output
    cols: int
    num_lower: int, maximum distance backward. Negative values indicate
      unlimited.
    num_upper: int, maximum distance forward. Negative values indicate
      unlimited.
    out_shape: shape to reshape output by.

  Returns:
    float32 Tensor of size rows * cols reshaped into shape out_shape.
  """
  static_args = (rows, cols, num_lower, num_upper)
  if not all(isinstance(arg, int) for arg in static_args):
    band = tf.linalg.band_part(
        tf.ones([rows, cols]), tf.cast(num_lower, tf.int64),
        tf.cast(num_upper, tf.int64))
    if out_shape:
      band = tf.reshape(band, out_shape)
    return band

  # Everything is a Python int, so build the band as a numpy constant.
  lower = rows - 1 if num_lower < 0 else num_lower
  upper = cols - 1 if num_upper < 0 else num_upper
  below_ok = np.tri(cols, rows, lower).T
  above_ok = np.tri(rows, cols, upper)
  band = np.ones((rows, cols)) * below_ok * above_ok
  if out_shape:
    band = band.reshape(out_shape)
  return tf.constant(band, tf.float32)
def reshape_like_all_dims(a, b):
  """Reshape `a` to exactly the shape of `b`.

  Outside eager mode the static shape of `b` is also attached to the result.
  """
  reshaped = tf.reshape(a, tf.shape(b))
  if tf.executing_eagerly():
    return reshaped
  reshaped.set_shape(b.get_shape())
  return reshaped
def recompute_grad(fn):
  """Decorator that recomputes `fn` on the backwards pass.

  Args:
    fn: a function that takes Tensors (all as positional arguments) and
      returns a tuple of Tensors.

  Returns:
    A wrapped fn that behaves like fn when called. Its activations are
    discarded and recomputed on the backwards pass (i.e. on a call to
    tf.gradients).
  """
  def wrapped(*args):
    return _recompute_grad(fn, args)
  return functools.wraps(fn)(wrapped)
def _recompute_grad(fn, args):
  """See recompute_grad.

  Runs `fn(*args)` through `fn_with_custom_grad`. The gradient function
  re-runs `fn` with reuse=True, in the variable scope and arg scope captured
  during the forward call, and differentiates the recomputed outputs.
  """
  cached_vs = []
  cached_arg_scope = []

  def grad_fn(inputs, variables, outputs, output_grads):
    """Recompute outputs for gradient computation."""
    del outputs
    variables = [underlying_variable_ref(v) for v in variables]
    # Recompute outputs
    with tf.control_dependencies(output_grads):
      with contrib.framework().arg_scope(cached_arg_scope[0]):
        with tf.variable_scope(cached_vs[0], reuse=True):
          outputs = fn(*inputs)

    if not isinstance(outputs, (list, tuple)):
      outputs = [outputs]
    outputs = list(outputs)
    grads = tf.gradients(outputs, inputs + variables, output_grads)
    grad_inputs = grads[:len(inputs)]
    grad_vars = grads[len(inputs):]
    # TODO(rsepassi): Make fn_with_custom_grad work with bfloat16.
    # If the input gradients are bfloat16, it's assumed the variables are
    # bfloat16. This is a hack to ensure that grad_vars are the right type.
    # tf.gradients yields None for unconnected tensors, so inspect the first
    # non-None input gradient and leave None variable gradients as None.
    first_grad = next((g for g in grad_inputs if g is not None), None)
    if first_grad is not None and first_grad.dtype == tf.bfloat16:
      grad_vars = [
          None if grad_var is None else tf.cast(grad_var, tf.bfloat16)
          for grad_var in grad_vars
      ]
    return grad_inputs, grad_vars

  @fn_with_custom_grad(grad_fn)
  def fn_with_recompute(*args):
    # Remember the scopes so grad_fn can rebuild fn with the same variables.
    cached_vs.append(tf.get_variable_scope())
    cached_arg_scope.append(contrib.framework().current_arg_scope())
    return fn(*args)

  return fn_with_recompute(*args)
def dense(x, units, **kwargs):
  """Identical to layers.dense.

  If a `layer_collection` keyword argument is given, the layer's parameters
  are also registered with it. Rank-3 inputs use
  `register_fully_connected_multi`; other ranks use
  `register_fully_connected`. This path requires a non-empty variable scope
  and an explicit `name` keyword argument.

  Args:
    x: input Tensor.
    units: output dimension.
    **kwargs: passed to the Dense layer, plus the optional `layer_collection`.

  Returns:
    The layer's output Tensor.

  Raises:
    ValueError: if `layer_collection` is given but the variable scope or the
      layer name is empty.
  """
  layer_collection = kwargs.pop("layer_collection", None)
  activations = layers().Dense(units, **kwargs)(x)
  if layer_collection:
    # We need to find the layer parameters using scope name for the layer, so
    # check that the layer is named. Otherwise parameters for different layers
    # may get mixed up.
    layer_name = tf.get_variable_scope().name
    if (not layer_name) or ("name" not in kwargs):
      raise ValueError(
          "Variable scope and layer name cannot be empty. Actual: "
          "variable_scope={}, layer name={}".format(
              layer_name, kwargs.get("name", None)))

    layer_name += "/" + kwargs["name"]
    layer_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=layer_name)
    assert layer_params
    if len(layer_params) == 1:
      layer_params = layer_params[0]

    tf.logging.info(
        "Registering dense layer to collection for tensor: {}".format(
            layer_params))

    x_shape = x.shape.as_list()
    if len(x_shape) == 3:
      # Handle [batch, time, depth] inputs by folding batch and time into
      # one dimension: reshaping inputs to [batchxtime, depth].
      x_2d = tf.reshape(x, [-1, x_shape[2]])
      activations_shape = activations.shape.as_list()
      activations_2d = tf.reshape(activations, [-1, activations_shape[2]])
      # NOTE(review): num_uses is None when the time dimension is not static.
      layer_collection.register_fully_connected_multi(
          layer_params, x_2d, activations_2d, num_uses=x_shape[1])
      # shape_list fills unknown (None) batch/time dims from the dynamic
      # shape; a static shape list containing None cannot be fed to
      # tf.reshape.
      activations = tf.reshape(activations_2d, shape_list(activations))
    else:
      layer_collection.register_fully_connected(layer_params, x, activations)
  return activations
def batch_dense(inputs,
                units,
                activation=None,
                kernel_initializer=None,
                reuse=None,
                name=None):
  """Multiply a batch of input matrices by a batch of parameter matrices.

  Input matrix i is multiplied by its own parameter matrix i. This is useful
  in a mixture-of-experts, where the batch axis indexes experts that receive
  different inputs.

  Args:
    inputs: a Tensor with shape [batch, length, input_units]
    units: an integer
    activation: an optional activation function to apply to the output
    kernel_initializer: an optional initializer; defaults to a normal
      distribution with stddev input_units**-0.5.
    reuse: whether to reuse the variable scope
    name: an optional string

  Returns:
    a Tensor with shape [batch, length, units]

  Raises:
    ValueError: if inputs is not rank 3, or if its "batch" or "input_units"
      dimensions are not statically known.
  """
  inputs_shape = shape_list(inputs)
  if len(inputs_shape) != 3:
    raise ValueError("inputs must have 3 dimensions")
  batch, _, input_units = inputs_shape
  if not (isinstance(batch, int) and isinstance(input_units, int)):
    raise ValueError("inputs must have static dimensions 0 and 2")
  with tf.variable_scope(
      name,
      default_name="batch_dense",
      values=[inputs],
      reuse=reuse,
      dtype=inputs.dtype):
    initializer = kernel_initializer
    if initializer is None:
      initializer = tf.random_normal_initializer(stddev=input_units**-0.5)
    kernel = tf.get_variable(
        "w", [batch, input_units, units],
        initializer=initializer,
        dtype=inputs.dtype)
    outputs = tf.matmul(inputs, kernel)
    return outputs if activation is None else activation(outputs)
def mix(x1,
        x2,
        steps,
        is_training,
        min_prob=0.0,
        max_prob=1.0,
        mode="lin",
        simple=False,
        broadcast_last=False):
  """Randomly mix x1 and x2, shifting from x2 towards x1 over `steps`.

  The probability of picking x1 is decayed from `min_prob` to `max_prob`
  ("lin" uses inverse_lin_decay, anything else inverse_exp_decay). With
  `simple`, the two inputs are interpolated by that probability instead of
  being sampled. With `broadcast_last`, one draw is shared across the last
  axis. Outside training, x1 is returned when max_prob >= 1.0; otherwise x1 is
  picked with probability max_prob.
  """
  def random_gate(prob):
    """0/1 float mask shaped like x1 (last axis 1 if broadcast_last)."""
    gate_shape = shape_list(x1)
    if broadcast_last:
      gate_shape = gate_shape[:-1] + [1]
    return to_float(tf.less(tf.random_uniform(gate_shape), prob))

  def blend(weight):
    return weight * x1 + (1.0 - weight) * x2

  with tf.name_scope("mix"):
    if not is_training:
      if max_prob >= 1.0:
        return x1
      return blend(random_gate(max_prob))

    def get_res():
      """Mix according to the decayed probability for the current step."""
      decay_fn = inverse_lin_decay if mode == "lin" else inverse_exp_decay
      alpha_p = decay_fn(steps) * (max_prob - min_prob) + min_prob
      if simple:
        return blend(alpha_p)
      return blend(random_gate(alpha_p))

    if max_prob < 1.0 or is_xla_compiled():
      return get_res()
    cur_step = tf.train.get_global_step()
    if cur_step is None:
      return x1  # Step not available, probably eval mode, don't mix.
    # Once `steps` have passed, skip sampling entirely to speed things up.
    return tf.cond(tf.less(cur_step, steps), get_res, lambda: x1)
def brelu(x):
  """Bipolar ReLU as in https://arxiv.org/abs/1709.04054.

  The last axis is viewed as consecutive pairs. The first member of each pair
  goes through relu(v) and the second through -relu(-v).
  """
  shape = shape_list(x)
  paired = tf.reshape(x, shape[:-1] + [-1, 2])
  first, second = tf.split(paired, 2, axis=-1)
  activated = tf.concat([tf.nn.relu(first), -tf.nn.relu(-second)], axis=-1)
  return tf.reshape(activated, shape)
def belu(x):
  """Bipolar ELU as in https://arxiv.org/abs/1709.04054.

  The last axis is viewed as consecutive pairs. The first member of each pair
  goes through elu(v) and the second through -elu(-v).
  """
  shape = shape_list(x)
  paired = tf.reshape(x, shape[:-1] + [-1, 2])
  first, second = tf.split(paired, 2, axis=-1)
  activated = tf.concat([tf.nn.elu(first), -tf.nn.elu(-second)], axis=-1)
  return tf.reshape(activated, shape)
def gelu(x):
  """Gaussian Error Linear Unit, a smoother version of the RELU.

  Uses the tanh approximation of the Gaussian CDF.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    x with the GELU activation applied.
  """
  inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
  approx_cdf = 0.5 * (1.0 + tf.tanh(inner))
  return x * approx_cdf
def nac(x, depth, name=None, reuse=None):
  """NAC as in https://arxiv.org/abs/1808.00508.

  Applies a linear map along the last axis whose weights are
  tanh(w) * sigmoid(m).
  """
  with tf.variable_scope(name, default_name="nac", values=[x], reuse=reuse):
    shape = shape_list(x)
    in_depth = shape[-1]
    w_hat = tf.get_variable("w", [in_depth, depth])
    m_hat = tf.get_variable("m", [in_depth, depth])
    weights = tf.tanh(w_hat) * tf.nn.sigmoid(m_hat)
    flat = tf.reshape(x, [-1, in_depth])
    return tf.reshape(tf.matmul(flat, weights), shape[:-1] + [depth])
def nalu(x, depth, epsilon=1e-30, name=None, reuse=None):
  """NALU as in https://arxiv.org/abs/1808.00508.

  A learned sigmoid gate blends an additive NAC path with a multiplicative
  path: exp(NAC(log(|x| + epsilon))).
  """
  with tf.variable_scope(name, default_name="nalu", values=[x], reuse=reuse):
    shape = shape_list(x)
    flat = tf.reshape(x, [-1, shape[-1]])
    gate_w = tf.get_variable("w", [shape[-1], depth])
    gate = tf.reshape(
        tf.nn.sigmoid(tf.matmul(flat, gate_w)), shape[:-1] + [depth])
    additive = nac(x, depth, name="nac_lin")
    log_magnitude = tf.log(tf.abs(x) + epsilon)
    multiplicative = tf.exp(nac(log_magnitude, depth, name="nac_log"))
    return gate * additive + (1 - gate) * multiplicative
def argmax_with_score(logits, axis=None):
  """Argmax along with the value.

  Args:
    logits: a Tensor of known rank.
    axis: the axis to reduce over. Defaults to the last axis; negative values
      count from the end.

  Returns:
    predictions: Tensor of argmax indices, shaped like `logits` without
      `axis`.
    scores: Tensor of the corresponding maximal logits, same shape.
  """
  rank = len(logits.get_shape())
  # `axis or ...` would silently turn an explicit axis=0 into the last axis.
  if axis is None:
    axis = rank - 1
  elif axis < 0:
    axis += rank
  if axis != rank - 1:
    # Move the reduced axis last, keeping the other axes in order, so that the
    # flattening below lines up with the argmax output layout.
    perm = [d for d in range(rank) if d != axis] + [axis]
    logits = tf.transpose(logits, perm)
  predictions = tf.argmax(logits, axis=-1)

  logits_shape = shape_list(logits)
  prefix_shape, vocab_size = logits_shape[:-1], logits_shape[-1]
  prefix_size = 1
  for d in prefix_shape:
    prefix_size *= d

  # Flatten to extract scores
  flat_logits = tf.reshape(logits, [prefix_size, vocab_size])
  flat_predictions = tf.reshape(predictions, [prefix_size])
  flat_indices = tf.stack(
      [tf.range(tf.to_int64(prefix_size)),
       tf.to_int64(flat_predictions)],
      axis=1)
  flat_scores = tf.gather_nd(flat_logits, flat_indices)

  # Unflatten
  scores = tf.reshape(flat_scores, prefix_shape)

  return predictions, scores
def log_prob_from_logits(logits, reduce_axis=-1):
  """Log-softmax: logits minus their logsumexp along `reduce_axis`."""
  log_normalizer = tf.reduce_logsumexp(
      logits, axis=reduce_axis, keepdims=True)
  return logits - log_normalizer
def top_kth_iterative(x, k):
  """Compute the k-th top element of x on the last axis iteratively.

  This assumes values in x are non-negative, rescale if needed.
  It is often faster than tf.nn.top_k for small k, especially if k < 30.
  Note: this does not support back-propagation, it stops gradients!

  Args:
    x: a Tensor of non-negative numbers of type float.
    k: a python integer.

  Returns:
    a float tensor of the same shape as x but with size 1 on the last axis,
    holding the k-th largest number in x.
  """
  # The iterative computation is as follows:
  #
  # cur_x = x
  # for _ in range(k):
  #   top_x = maximum of elements of cur_x on the last axis
  #   cur_x = cur_x where cur_x < top_x and 0 everywhere else (top elements)
  #
  # We encode this computation in a TF graph using tf.foldl, so the inner
  # part of the above loop is called "next_x" and tf.foldl does the loop.
  def next_x(cur_x, _):
    top_x = tf.reduce_max(cur_x, axis=-1, keepdims=True)
    # Zero out the current maxima (x is assumed non-negative).
    return cur_x * to_float(cur_x < top_x)
  # We only do k-1 steps of the loop and compute the final max separately.
  fin_x = tf.foldl(next_x, tf.range(k - 1), initializer=tf.stop_gradient(x),
                   parallel_iterations=2, back_prop=False)
  return tf.stop_gradient(tf.reduce_max(fin_x, axis=-1, keepdims=True))
def top_1_tpu(inputs):
  """Find max and argmax over the last dimension; works well on TPU.

  Marks positions equal to the maximum, multiplies the marks by their
  positions, and takes the largest result. With ties, the returned index is
  therefore the largest index among the maximal positions.

  Args:
    inputs: A tensor with shape [..., depth]

  Returns:
    values: a Tensor with shape [...]
    indices: a Tensor with shape [...]
  """
  max_value = tf.reduce_max(inputs, axis=-1, keepdims=True)
  is_max = tf.to_int32(tf.equal(max_value, inputs))
  positions = tf.range(tf.shape(inputs)[-1])
  argmax = tf.reduce_max(positions * is_max, axis=-1)
  return tf.squeeze(max_value, -1), argmax
def index_last_dim_with_indices(x, indices):
  """Select one entry from the last axis of `x` per position using `indices`.

  Handy for recovering the probabilities of sampled ids from a distribution.

  Args:
    x: Tensor, n-d.
    indices: Tensor, (n-1)-d, whose dimension sizes match the first (n-1)
      dimensions of x. Its values index into the last axis of x.

  Returns:
    Tensor, (n-1)-d.
  """
  assert len(x.shape) == len(indices.shape) + 1
  x_shape = shape_list(x)
  outer_size = list_product(x_shape[:-1])
  flat_x = tf.reshape(x, [outer_size, x_shape[-1]])
  flat_indices = tf.reshape(indices, [outer_size])
  row_ids = tf.range(tf.to_int64(shape_list(flat_indices)[0]))
  gather_idx = tf.stack([row_ids, tf.to_int64(flat_indices)], axis=1)
  picked = tf.gather_nd(flat_x, gather_idx)
  return tf.reshape(picked, x_shape[:-1])
def should_generate_summaries():
  """Whether summaries should be emitted in the current context.

  Returns False inside a tf.while_loop (the name scope contains "while/") and
  when the current variable scope is reusing variables. Returns True
  otherwise.

  Returns:
    a boolean
  """
  scope_name = contrib.framework().get_name_scope()
  # Summaries don't work well within tf.while_loop().
  inside_while = bool(scope_name) and "while/" in scope_name
  # Under variable reuse, skip summaries to avoid separate copies per data
  # shard.
  return not (inside_while or tf.get_variable_scope().reuse)
def reshape_like(a, b):
  """Reshape `a` to `b`'s shape in all but the last dimension.

  The last dimension keeps `a`'s size. Outside eager mode the combined static
  shape is also attached.
  """
  target = tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0)
  result = tf.reshape(a, target)
  if tf.executing_eagerly():
    return result
  static = b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:]
  result.set_shape(static)
  return result
def summarize_video(video, prefix, max_outputs=1):
  """Summarize the video using image summaries starting with prefix.

  `video` must be rank 5 ([batch, time, height, width, channels]). In eager
  mode nothing is emitted. If the number of frames is not statically known,
  only the last frame is summarized; otherwise every frame is.

  Raises:
    ValueError: if `video` is not rank 5.
  """
  video_shape = shape_list(video)
  if len(video_shape) != 5:
    raise ValueError("Assuming videos given as tensors in the format "
                     "[batch, time, height, width, channels] but got one "
                     "of shape: %s" % str(video_shape))
  if tf.executing_eagerly():
    return
  if video.get_shape().as_list()[1] is None:
    tf.summary.image(
        "%s_last_frame" % prefix,
        tf.cast(video[:, -1, :, :, :], tf.uint8),
        max_outputs=max_outputs)
    return
  for frame in range(video_shape[1]):
    tf.summary.image(
        "%s_frame_%d" % (prefix, frame),
        tf.cast(video[:, frame, :, :, :], tf.uint8),
        max_outputs=max_outputs)
def cast_like(x, y):
  """Cast `x` to `y`'s dtype; `x` is returned unchanged if base dtypes match.

  A warning is logged when the cast result lives on a different device than
  `x`, since that may imply a copy.
  """
  x = tf.convert_to_tensor(x)
  y = tf.convert_to_tensor(y)
  if x.dtype.base_dtype == y.dtype.base_dtype:
    return x
  result = tf.cast(x, y.dtype)
  if result.device == x.device:
    return result
  # Eager tensors raise AttributeError for .name; fall back to a label.
  label = getattr(x, "name", "(eager Tensor)")
  tf.logging.warning("Cast for %s may induce copy from '%s' to '%s'", label,
                     x.device, result.device)
  return result
def make_even_size(x):
  """Pad x to be even-sized on axis 1 and 2, but only if necessary.

  Only axes whose size is odd, or statically unknown, are padded. Unknown
  sizes are treated as -1, which is odd in Python. Statically known sizes are
  rounded up to the next even number in the attached static shape.
  """
  static_shape = x.get_shape().as_list()
  assert len(static_shape) > 2, "Only 3+-dimensional tensors supported."
  dims = [-1 if d is None else d for d in static_shape]
  odd_axes = [axis for axis in (1, 2) if dims[axis] % 2 != 0]
  if not odd_axes:
    return x
  # Keep known sizes as constants in the static shape after padding.
  target_shape = list(static_shape)
  for axis in (1, 2):
    if static_shape[axis] is not None:
      target_shape[axis] = 2 * int(math.ceil(static_shape[axis] * 0.5))
  for axis in odd_axes:
    x, _ = pad_to_same_length(x, x, final_length_divisible_by=2, axis=axis)
  x.set_shape(target_shape)
  return x
def sliced_gan_loss(input1,
                    input2,
                    discriminator,
                    num_vecs,
                    do_random_vecs=True,
                    do_tanh=True,
                    return_logits=False):
  """Loss inspired by the sliced WGAN paper: https://arxiv.org/abs/1804.01947.

  Runs both inputs through `discriminator` (the second call reuses the first
  call's variables), projects the resulting logits, sorts every projection
  along the (flattened) batch dimension and returns the mean squared
  difference between the two sorted sets.

  Args:
    input1: first discriminator inputs.
    input2: second discriminator inputs.
    discriminator: inputs -> logits function.
    num_vecs: how many random vectors to use for projections.
    do_random_vecs: if True, project the L2-normalized logits onto `num_vecs`
      random unit vectors; if False, use tanh of the raw logits instead.
    do_tanh: only consulted when `do_random_vecs` is True; if True, tanh of
      the normalized logits is concatenated to the random projections.
    return_logits: whether to also return both logit tensors.

  Returns:
    The generator loss, i.e., the sliced approximation of the distance between
    the projected distributions (warning: discriminator should maximize it).
    If `return_logits` is True, returns `(dist, logits1, logits2)`.
  """
  with tf.variable_scope("sliced_gan"):
    with tf.variable_scope("discriminator"):
      logits1 = discriminator(input1)
    with tf.variable_scope("discriminator", reuse=True):
      logits2 = discriminator(input2)

    random_vecs = None
    if do_random_vecs:
      random_vecs = tf.nn.l2_normalize(
          tf.random_uniform([shape_list(logits1)[-1], num_vecs]), axis=0)

    def _project(flat):
      """Project rows of the 2-D `flat` according to the chosen options."""
      if not do_random_vecs:
        return tf.tanh(flat)
      unit = tf.nn.l2_normalize(flat, axis=1)
      projected = tf.matmul(unit, random_vecs)
      if do_tanh:
        projected = tf.concat([projected, tf.tanh(unit)], axis=1)
      return projected

    def _sorted_projections(logits):
      """Project `logits` and sort each projection over the batch."""
      flat = tf.reshape(logits, [-1, shape_list(logits)[-1]])
      batch_size = shape_list(flat)[0]
      # Transpose to [num_projections, batch] so top_k sorts over the batch.
      projections = tf.transpose(_project(flat), [1, 0])
      if not is_xla_compiled():
        return tf.nn.top_k(projections, k=batch_size, sorted=True)[0]
      orig_dtype = projections.dtype
      projections = tf.cast(projections, tf.bfloat16)
      # Currently TPU only supports 1-D top_k calls, so sort row by row.
      sorted_rows = tf.map_fn(
          lambda row: tf.nn.top_k(row, k=batch_size, sorted=True)[0],
          projections)
      return tf.cast(sorted_rows, orig_dtype)

    proj1 = _sorted_projections(logits1)
    proj2 = _sorted_projections(logits2)
    dist = tf.reduce_mean(tf.squared_difference(proj1, proj2))
    if return_logits:
      return dist, logits1, logits2
    return dist
def lrelu(input_, leak=0.2, name="lrelu"):
  """Leaky ReLU: elementwise max(input_, leak * input_), op named `name`."""
  scaled = leak * input_
  return tf.maximum(input_, scaled, name=name)
def deep_discriminator(x,
                       batch_norm,
                       is_training,
                       filters=64,
                       filter_size=4,
                       stride=2,
                       output_size=1024):
  """Discriminator architecture based on InfoGAN.

  Two strided "SAME" convolutions (leaky ReLU after each, optional batch norm
  after the second), a spatial reduction, a dense layer of `output_size`
  units with optional batch norm, and a final leaky ReLU.

  Args:
    x: input Tensor, presumably [batch, height, width, channels] (TODO
      confirm with callers).
    batch_norm: whether to apply batch normalization after conv2 and after
      the dense layer.
    is_training: forwarded as the `training` call argument of the batch norm
      layers.
    filters: filters of the first conv; the second conv uses 2 * filters.
    filter_size: kernel size of both convs.
    stride: stride of both convs.
    output_size: number of units of the dense layer.

  Returns:
    Tensor of shape [batch, output_size].
  """
  with tf.variable_scope(
      "discriminator", initializer=tf.random_normal_initializer(stddev=0.02)):
    batch_size = shape_list(x)[0]
    net = layers().Conv2D(
        filters, filter_size, strides=stride, padding="SAME", name="conv1")(x)
    net = lrelu(net)
    net = layers().Conv2D(
        2 * filters,
        filter_size,
        strides=stride,
        padding="SAME",
        name="conv2")(net)
    # [bs, h/stride^2, w/stride^2, 2 * filters] (rounded up, "SAME" padding).
    if batch_norm:
      # `training` is a call-time argument of BatchNormalization; passing it
      # to the constructor raises TypeError.
      net = layers().BatchNormalization(
          momentum=0.999, name="d_bn2")(net, training=is_training)
    net = lrelu(net)
    x_shape = x.get_shape().as_list()
    if x_shape[1] is None or x_shape[2] is None:
      # Unknown spatial size: average over the spatial axes instead.
      net = tf.reduce_mean(net, axis=[1, 2])
    else:
      # Flatten using the real post-conv shape. The previous hard-coded
      # `height * width * 8` only matched filters=64, stride=2 and spatial
      # dims divisible by 4.
      net_h, net_w, net_c = shape_list(net)[1:4]  # pylint: disable=unbalanced-tuple-unpacking
      net = tf.reshape(net, [batch_size, net_h * net_w * net_c])
    net = layers().Dense(output_size, name="d_fc3")(net)
    if batch_norm:
      net = layers().BatchNormalization(
          momentum=0.999, name="d_bn3")(net, training=is_training)
    net = lrelu(net)
    return net
def instance_norm(x):
  """Instance normalization layer.

  Normalizes `x` by its mean and variance over axes 1 and 2 (kept as size-1
  dims), then applies a learned per-channel `scale` (initialized around 1.0)
  and `offset` (initialized to 0.0), both created under "instance_norm".
  For a [batch, h, w, channels] input this is per-example, per-channel
  normalization (assumed layout; TODO confirm).
  """
  with tf.variable_scope("instance_norm"):
    epsilon = 1e-5
    mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
    channels = x.get_shape()[-1]
    scale = tf.get_variable(
        "scale", [channels],
        initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))
    offset = tf.get_variable(
        "offset", [channels], initializer=tf.constant_initializer(0.0))
    normalized = tf.div(x - mean, tf.sqrt(var + epsilon))
    return scale * normalized + offset
def general_conv(x,
                 num_filters=64,
                 filter_size=7,
                 stride=1,
                 stddev=0.02,
                 padding="VALID",
                 name="conv",
                 do_norm="instance",
                 do_relu=True,
                 relufactor=0):
  """Conv2D followed by optional normalization and optional (leaky) ReLU.

  Args:
    x: input Tensor.
    num_filters: number of conv filters.
    filter_size: conv kernel size.
    stride: conv stride.
    stddev: stddev of the truncated-normal kernel initializer.
    padding: conv padding.
    name: variable scope name.
    do_norm: "layer" for layer_norm, "instance" for instance_norm, anything
      else for no normalization.
    do_relu: whether to apply an activation after normalization.
    relufactor: 0 selects a plain ReLU; any other value is the leak of lrelu.

  Returns:
    The output Tensor.
  """
  with tf.variable_scope(name):
    conv = layers().Conv2D(
        num_filters,
        filter_size,
        strides=stride,
        padding=padding,
        activation=None,
        kernel_initializer=tf.truncated_normal_initializer(stddev=stddev),
        bias_initializer=tf.constant_initializer(0.0))
    out = conv(x)
    if do_norm == "layer":
      out = layer_norm(out)
    elif do_norm == "instance":
      out = instance_norm(out)
    if not do_relu:
      return out
    if relufactor == 0:
      return tf.nn.relu(out, "relu")
    return lrelu(out, leak=relufactor)
def patch_discriminator(x, filters=64, filter_size=5, n=4,
                        name="patch_discrim"):
  """Patch discriminator.

  Randomly crops a patch of a quarter of the height and width, then applies
  `n` general_conv layers with `filters * 2**i` filters. Every layer but the
  last uses stride 2 and leaky ReLU (0.2); every layer but the first uses
  instance norm. The result is averaged over axes 1 and 2.
  """
  with tf.variable_scope(name):
    dims = shape_list(x)
    crop_size = [dims[0], dims[1] // 4, dims[2] // 4, dims[3]]
    net = tf.random_crop(x, crop_size)
    last_idx = n - 1
    for layer_idx in range(n):
      is_last = layer_idx == last_idx
      net = general_conv(
          x=net,
          num_filters=filters * 2**layer_idx,
          filter_size=filter_size,
          stride=1 if is_last else 2,
          stddev=0.02,
          padding="SAME",
          name="c%d" % layer_idx,
          do_norm=False if layer_idx == 0 else "instance",
          do_relu=not is_last,
          relufactor=0.2)
    return tf.reduce_mean(net, [1, 2])
def mean_with_attention(x, name, num_heads=4):
  """Mean and attention to reduce spatial dimensions.

  Computes a plain mean over axes 1 and 2, plus `num_heads` attention-weighted
  reductions. The attention weights are softmax-normalized over all spatial
  positions. The results are concatenated, flattened, and projected by a
  dense layer to `2 * channels` units. For a [batch, h, w, channels] input
  (assumed; TODO confirm) the output is [batch, 2 * channels].
  """
  with tf.variable_scope(name):
    shape = shape_list(x)
    plain_mean = tf.reduce_mean(x, [1, 2])
    attn_logits = layers().Dense(num_heads, name="mean_attn")(x)
    attn = tf.reshape(attn_logits, [shape[0], -1, num_heads])
    attn = tf.nn.softmax(attn, axis=1)
    attn = tf.reshape(attn, shape[:-1] + [1, num_heads])
    weighted = tf.expand_dims(x, axis=-1) * attn
    attended = tf.reduce_mean(weighted, [1, 2])
    pooled = tf.concat([attended, tf.expand_dims(plain_mean, axis=-1)],
                       axis=-1)
    flat = tf.reshape(pooled, [shape[0], (num_heads + 1) * shape[-1]])
    return layers().Dense(2 * shape[-1], name="mean_attn_final")(flat)
def single_discriminator(x, filters=128, kernel_size=8,
                         strides=4, pure_mean=False):
  """A simple single-layer convolutional discriminator.

  One strided "SAME" conv, then a spatial reduction: a plain mean over axes
  1 and 2 if `pure_mean`, otherwise `mean_with_attention`.
  """
  with tf.variable_scope("discriminator"):
    conv = layers().Conv2D(
        filters, kernel_size, strides=strides, padding="SAME", name="conv1")
    features = conv(x)
    if pure_mean:
      return tf.reduce_mean(features, [1, 2])
    return mean_with_attention(features, "mean_with_attention")
def double_discriminator(x, filters1=128, filters2=None,
                         kernel_size=8, strides=4, pure_mean=False):
  """A convolutional discriminator with 2 layers and concatenated output.

  Each of the two strided "SAME" convs is reduced over axes 1 and 2 (plain
  mean if `pure_mean`, else `mean_with_attention`). A ReLU sits between the
  convs. The two reductions are concatenated on the last axis.

  Args:
    x: input Tensor, presumably [batch, height, width, channels] (TODO
      confirm).
    filters1: filters of the first conv.
    filters2: filters of the second conv; defaults to 4 * filters1.
    kernel_size: kernel size of both convs.
    strides: stride of both convs.
    pure_mean: whether to use a plain mean instead of mean_with_attention.

  Returns:
    The concatenation of both per-layer reductions along the last axis.
  """
  if filters2 is None:
    filters2 = 4 * filters1
  with tf.variable_scope("discriminator"):
    net = layers().Conv2D(
        filters1, kernel_size, strides=strides, padding="SAME", name="conv1")(x)
    if pure_mean:
      net1 = tf.reduce_mean(net, [1, 2])
    else:
      net1 = mean_with_attention(net, "mean_with_attention1")
    # (A discarded `tf.reshape(net, [batch_size, -1])` used to sit here; it
    # had no effect and was removed.)
    net = tf.nn.relu(net)
    net = layers().Conv2D(
        filters2, kernel_size, strides=strides, padding="SAME",
        name="conv2")(net)
    if pure_mean:
      net2 = tf.reduce_mean(net, [1, 2])
    else:
      net2 = mean_with_attention(net, "mean_with_attention2")
    return tf.concat([net1, net2], axis=-1)
def upscale(inputs, f, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR):
  """Resize `inputs` so axes 1 and 2 become `f` times larger, using `method`."""
  in_height, in_width = shape_list(inputs)[1:3]  # pylint: disable=unbalanced-tuple-unpacking
  new_size = (in_height * f, in_width * f)
  return tf.image.resize_images(inputs, new_size, method)
def tpu_safe_image_summary(image):
  """Prepare `image` for an image summary.

  Under XLA compilation the image is kept as float32 (converted with
  `to_float` if needed); otherwise it is cast to uint8.
  """
  if not is_xla_compiled():
    return tf.cast(image, tf.uint8)
  # We only support float32 images at the moment due to casting complications.
  return image if image.dtype == tf.float32 else to_float(image)
# This has been (shamefully) copied from
# GitHub tensorflow/models/blob/master/research/slim/nets/cyclegan.py
#
# tensorflow/models cannot be pip installed, and even if it were we don't want
# to depend on all the models in it.
#
# Therefore copying and forgoing any more bugfixes into it is the most
# expedient way to use this function.
def cyclegan_upsample(net, num_outputs, stride, method="conv2d_transpose"):
  """Upsamples the given inputs.

  Args:
    net: A Tensor of size [batch_size, height, width, filters].
    num_outputs: The number of output filters.
    stride: A list of 2 scalars or a 1x2 Tensor indicating the scale,
      relative to the inputs, of the output dimensions. For example, if
      stride is [2, 3], then the output height and width will be twice and
      three times the input size.
    method: The upsampling method: 'nn_upsample_conv',
      'bilinear_upsample_conv', or 'conv2d_transpose'.

  Returns:
    A Tensor which was upsampled using the specified method.

  Raises:
    ValueError: if `method` is not recognized.
  """
  with tf.variable_scope("upconv"):
    dynamic_shape = tf.shape(net)
    height = dynamic_shape[1]
    width = dynamic_shape[2]
    # Reflection pad by 1 in spatial dimensions (axes 1, 2 = h, w) so that a
    # 3x3 "valid" convolution preserves the spatial size.
    spatial_pad_1 = np.array([[0, 0], [1, 1], [1, 1], [0, 0]])
    if method in ("nn_upsample_conv", "bilinear_upsample_conv"):
      target_size = [stride[0] * height, stride[1] * width]
      if method == "nn_upsample_conv":
        net = tf.image.resize_nearest_neighbor(net, target_size)
      else:
        net = tf.image.resize_bilinear(net, target_size)
      net = tf.pad(net, spatial_pad_1, "REFLECT")
      net = layers().Conv2D(num_outputs, (3, 3), activation=tf.nn.relu)(net)
    elif method == "conv2d_transpose":
      # This corrects 1 pixel offset for images with even width and height.
      # conv2d is left aligned and conv2d_transpose is right aligned for even
      # sized images (while doing "SAME" padding).
      # Note: This doesn't reflect actual model in paper.
      net = layers().Conv2DTranspose(
          num_outputs, (3, 3), strides=stride, activation=tf.nn.relu)(net)
      net = net[:, 1:, 1:, :]
    else:
      raise ValueError("Unknown method: [%s]" % method)
    return net
def weight_targeting(w, k):
  """Weight-level magnitude pruning mask.

  Folds all but the last axis of `w` into one, giving [size, units]. For each
  unit (column) it takes the element at index `k` of that column's sorted
  absolute weights as a threshold. Every weight whose magnitude is <= the
  threshold is marked with 1 (via `to_float`). The mask has `w`'s shape.
  """
  k = tf.to_int32(k)
  w_shape = shape_list(w)
  size = tf.to_int32(tf.reduce_prod(w_shape[:-1]))
  flat = tf.reshape(w, [size, w_shape[-1]])
  sorted_per_unit = contrib.framework().sort(
      tf.abs(tf.transpose(flat)), axis=1)
  threshold = sorted_per_unit[:, k]
  mask = to_float(threshold[None, :] >= tf.abs(flat))
  return tf.reshape(mask, w_shape)
def unit_targeting(w, k):
  """Unit-level magnitude pruning mask.

  Folds all but the last axis of `w` into one, giving [size, units], and
  computes each unit's (column's) L2 norm. Units whose norm is <= the norm at
  index `k` of the sorted norms are marked with 1 across the whole column.
  The mask has `w`'s shape.
  """
  k = tf.to_int32(k)
  w_shape = shape_list(w)
  size = tf.to_int32(tf.reduce_prod(w_shape[:-1]))
  flat = tf.reshape(w, [size, w_shape[-1]])
  unit_norms = tf.norm(flat, axis=0)
  threshold = contrib.framework().sort(unit_norms, axis=0)[k]
  unit_mask = to_float(threshold >= unit_norms)[None, :]
  return tf.reshape(tf.tile(unit_mask, [size, 1]), w_shape)
def td_conv(inputs,
            filters,
            kernel_size,
            targeting_count,
            targeting_fn,
            keep_prob,
            is_training,
            do_prune=True,
            strides=(1, 1),
            padding="valid",
            data_format="channels_last",
            dilation_rate=(1, 1),
            activation=None,
            use_bias=True,
            kernel_initializer=None,
            bias_initializer=tf.zeros_initializer(),
            name=None,
            reuse=None):
  """Apply targeted dropout to the weights of a convolution.

  Creates a kernel "DW" (and optional bias "b"). If `keep_prob < 1.0` the
  kernel passes through `targeted_dropout`, and the result is applied with
  `tf.nn.conv2d`.

  Args:
    inputs: input Tensor laid out according to `data_format`.
    filters: number of output filters.
    kernel_size: int, or a (height, width) pair.
    targeting_count: `k` argument forwarded to `targeted_dropout`.
    targeting_fn: targeting function forwarded to `targeted_dropout`.
    keep_prob: keep probability; targeted dropout is skipped when >= 1.0.
    is_training: forwarded to `targeted_dropout`.
    do_prune: forwarded to `targeted_dropout`.
    strides: int or (h, w) pair.
    padding: "valid"/"same" in any case (normalized to the upper-case form
      `tf.nn.conv2d` requires) or an explicit padding list (passed through).
    data_format: "channels_last" (NHWC) or anything else (NCHW).
    dilation_rate: int or (h, w) pair.
    activation: optional callable applied to the output.
    use_bias: whether to add a bias.
    kernel_initializer: initializer for "DW".
    bias_initializer: initializer for "b".
    name: variable scope name (default "td_conv").
    reuse: variable scope reuse flag.

  Returns:
    The convolution output Tensor.
  """
  with tf.variable_scope(name, default_name="td_conv", reuse=reuse):
    nhwc = data_format == "channels_last"
    in_dim = shape_list(inputs)[-1] if nhwc else shape_list(inputs)[1]
    if isinstance(kernel_size, (list, tuple)):
      kernel_h, kernel_w = kernel_size
    else:
      kernel_h = kernel_w = kernel_size
    kernel_shape = [kernel_h, kernel_w, in_dim, filters]
    w = tf.get_variable(
        "DW", shape=kernel_shape, initializer=kernel_initializer)
    if use_bias:
      b = tf.get_variable("b", shape=[filters], initializer=bias_initializer)
    if keep_prob < 1.0:
      w = targeted_dropout(
          w,
          targeting_count,
          keep_prob,
          targeting_fn,
          is_training,
          do_prune=do_prune)
    if isinstance(strides, int):
      strides = [strides, strides]
    if isinstance(dilation_rate, int):
      dilation_rate = [dilation_rate, dilation_rate]
    if nhwc:
      strides = [1, strides[0], strides[1], 1]
      dilation_rate = [1, dilation_rate[0], dilation_rate[1], 1]
    else:
      strides = [1, 1, strides[0], strides[1]]
      dilation_rate = [1, 1, dilation_rate[0], dilation_rate[1]]
    if isinstance(padding, str):
      # tf.nn.conv2d only accepts upper-case "SAME"/"VALID", but this
      # function's Keras-style default is lower-case "valid".
      padding = padding.upper()
    y = tf.nn.conv2d(
        inputs,
        w,
        strides,
        padding,
        data_format="NHWC" if nhwc else "NCHW",
        dilations=dilation_rate,
        name=None)
    if use_bias:
      y += b
    if activation:
      y = activation(y)
    return y
def targeted_dropout(inputs,
                     k,
                     keep_prob,
                     targeting_fn,
                     is_training,
                     do_prune=False):
  """Applies targeted dropout.

  Dropout at rate `1 - keep_prob` is applied only to the elements of `inputs`
  selected by `targeting_fn`. See "Targeted Dropout for Posthoc Pruning",
  Aidan N. Gomez, Ivan Zhang, Kevin Swersky, Yarin Gal, Geoffrey E. Hinton.

  Args:
    inputs: Tensor, inputs to apply targeted dropout to.
    k: Scalar Tensor or python scalar, sets the number of elements to target in
      `inputs`. Must be within `[0, tf.shape(x)[-1]]` and compatible with
      second argument of `targeting_fn`.
    keep_prob: Scalar Tensor, passed as `tf.nn.dropout`'s `keep_prob` argument.
    targeting_fn: callable `fn(inputs, k) -> mask` whose result is cast to
      `inputs.dtype`; nonzero entries mark elements to drop.
    is_training: bool, indicates whether currently training.
    do_prune: bool; outside training, rescales `k` to `round(k * (1 -
      keep_prob))` and zeroes the targeted elements.

  Returns:
    Tensor, same shape and dtype as `inputs`.
  """
  if not is_training and do_prune:
    k = tf.round(to_float(k) * to_float(1. - keep_prob))
  mask = tf.cast(targeting_fn(inputs, k), inputs.dtype)
  if is_training:
    untargeted = inputs * (1 - mask)
    return untargeted + tf.nn.dropout(inputs, keep_prob) * mask
  if do_prune:
    return inputs * (1 - mask)
  return inputs
def kl_divergence(mu, log_var, mu_p=0.0, log_var_p=0.0):
  """KL(N(mu, exp(log_var)) || N(mu_p, exp(log_var_p))), averaged per example.

  With the default prior arguments the prior is the standard normal N(0, 1).
  The elementwise KL is summed over all elements and divided by the size of
  the leading dimension of `mu`.

  Args:
    mu: mu parameter of the distribution.
    log_var: log(var) parameter of the distribution.
    mu_p: optional mu from a learned prior distribution.
    log_var_p: optional log(var) from a learned prior distribution.

  Returns:
    the KL loss.
  """
  batch_size = shape_list(mu)[0]
  prior = tfp.distributions.Normal(mu_p, tf.exp(tf.multiply(0.5, log_var_p)))
  posterior = tfp.distributions.Normal(mu, tf.exp(tf.multiply(0.5, log_var)))
  per_element_kl = tfp.distributions.kl_divergence(posterior, prior)
  return tf.reduce_sum(per_element_kl) / to_float(batch_size)
def sparse_equals_constant(constant, tensor):
  """Compare the stored values of SparseTensor `tensor` with `constant`.

  Returns a boolean SparseTensor with the same indices and dense shape whose
  values are `tensor.values == constant`.
  """
  matches = tf.equal(tensor.values, constant)
  return tf.SparseTensor(indices=tensor.indices,
                         dense_shape=tensor.dense_shape,
                         values=matches)
def sparse_expand_dims(tensor, current_num_dims, axis=0):
  """Insert a size-1 dimension into SparseTensor `tensor` at `axis`.

  Args:
    tensor: SparseTensor of static rank `current_num_dims`.
    current_num_dims: rank of `tensor`, needed to unstack indices and shape.
    axis: insertion position; -1 appends the new dimension at the end.

  Returns:
    A SparseTensor of rank `current_num_dims + 1` with the same values.
  """
  if axis == -1:
    axis = current_num_dims
  num_entries = tf.shape(tensor.indices)[0]
  zero_col = tf.zeros([num_entries], dtype=tf.int64)
  index_cols = tf.unstack(tensor.indices, axis=1, num=current_num_dims)
  dims = tf.unstack(tensor.dense_shape, num=current_num_dims)
  index_cols.insert(axis, zero_col)
  dims.insert(axis, 1)
  expanded_indices = tf.stack(index_cols, axis=1)
  return tf.SparseTensor(
      indices=expanded_indices,
      values=tensor.values,
      dense_shape=tf.stack(dims))
def sparse_add_constant(constant, tensor):
  """Add `constant` to the stored values of SparseTensor `tensor`.

  Implicit (unstored) zeros are left untouched; indices and dense shape are
  preserved.
  """
  shifted = constant + tensor.values
  return tf.SparseTensor(indices=tensor.indices, values=shifted,
                         dense_shape=tensor.dense_shape)
def sparse_eye(size):
  """Return a `size` x `size` identity matrix as a SparseTensor.

  Args:
    size: scalar integer (Python int or int Tensor), number of rows/columns.

  Returns:
    tf.SparseTensor with value 1.0 (tf.ones default dtype) at every (i, i).
  """
  diag = tf.range(size)
  # SparseTensor indices must have shape [num_entries, rank] = [size, 2].
  # Stacking on axis 1 yields the (i, i) pairs; stacking on axis 0 (the old
  # code) produced a [2, size] tensor with the wrong layout.
  indices = tf.cast(tf.stack([diag, diag], axis=1), tf.int64)
  values = tf.ones(size)
  dense_shape = [tf.cast(size, tf.int64), tf.cast(size, tf.int64)]
  return tf.SparseTensor(
      indices=indices, values=values, dense_shape=dense_shape)
# modification from https://github.com/tensorflow/tensorflow/pull/21276
# without special initialization for g
class WeightNorm(tf.keras.layers.Wrapper):
  """Decouple weight magnitude and direction.

  This wrapper reparameterizes a layer by decoupling the weight's
  magnitude and direction. This speeds up convergence by improving the
  conditioning of the optimization problem.

  Weight Normalization: A Simple Reparameterization to Accelerate
  Training of Deep Neural Networks: https://arxiv.org/abs/1602.07868
  Tim Salimans, Diederik P. Kingma (2016)

  WeightNorm wrapper works for keras and tf layers.

  ```python
    net = WeightNorm(tf.keras.layers.Conv2D(2, 2, activation='relu'),
           input_shape=(32, 32, 3))(x)
    net = WeightNorm(tf.keras.layers.Conv2D(16, 5, activation='relu'))(net)
    net = WeightNorm(tf.keras.layers.Dense(120, activation='relu'))(net)
    net = WeightNorm(tf.keras.layers.Dense(n_classes))(net)
  ```

  Arguments:
    layer: a layer instance.
    data_init: accepted for API compatibility but currently ignored;
      `call` never runs the data-dependent initialization.

  Raises:
    ValueError: If not initialized with a `Layer` instance.
    ValueError: If `Layer` does not contain a `kernel` of weights
  """

  def __init__(self, layer, data_init=False, **kwargs):
    if not isinstance(layer, tf.keras.layers.Layer):
      raise ValueError(
          "Please initialize `WeightNorm` layer with a "
          "`Layer` instance. You passed: {input}".format(input=layer))

    super(WeightNorm, self).__init__(layer, **kwargs)
    self._track_trackable(layer, name="layer")

  def _compute_weights(self):
    """Set kernel = g * v / ||v||, normalizing over all but the last axis."""
    with tf.variable_scope("compute_weights"):
      self.layer.kernel = tf.nn.l2_normalize(
          self.layer.v, axis=self.norm_axes) * self.layer.g

  def _init_norm(self, weights):
    """Return the per-unit (last-axis) L2 norm of `weights`."""
    with tf.variable_scope("init_norm"):
      flat = tf.reshape(weights, [-1, self.layer_depth])
      return tf.reshape(tf.norm(flat, axis=0), (self.layer_depth,))

  def _data_dep_init(self, inputs):
    """Data dependent initialization for eager execution (currently unused)."""
    with tf.variable_scope("data_dep_init"):
      # Run the layer without its activation to get pre-activation stats.
      activation = self.layer.activation
      self.layer.activation = None
      x_init = self.layer.call(inputs)
      # `tf.moments` does not exist; the moments op lives in `tf.nn`.
      m_init, v_init = tf.nn.moments(x_init, self.norm_axes)
      scale_init = 1. / tf.sqrt(v_init + 1e-10)

    # Assign data dependent init values
    self.layer.g = self.layer.g * scale_init
    self.layer.bias = (-m_init * scale_init)
    self.layer.activation = activation
    self.initialized = True

  def build(self, input_shape=None):
    """Build `Layer`."""
    if not self.layer.built:
      self.layer.build(input_shape)
      self.layer.built = False

      if not hasattr(self.layer, "kernel"):
        raise ValueError("`WeightNorm` must wrap a layer that"
                         " contains a `kernel` for weights")

      # The kernel's filter or unit dimension is -1
      self.layer_depth = int(self.layer.kernel.shape[-1])
      self.norm_axes = list(range(self.layer.kernel.shape.ndims - 1))

      # Keep the original kernel as the direction `v`; `g` is the magnitude.
      self.layer.v = self.layer.kernel
      self.layer.g = self.layer.add_variable(
          name="g",
          shape=(self.layer_depth,),
          initializer=tf.ones_initializer,
          dtype=self.layer.kernel.dtype,
          trainable=True)

      # `g` is not initialized from ||v|| (no special initialization for g);
      # the kernel is computed directly.
      self._compute_weights()

      self.layer.built = True

    super(WeightNorm, self).build()
    self.built = True

  def call(self, inputs):
    """Call `Layer`."""
    # Data-dependent initialization (`_data_dep_init`) is intentionally not
    # run here, which is why `data_init` has no effect.
    self._compute_weights()  # Recompute weights for each forward pass

    output = self.layer.call(inputs)
    return output

  def compute_output_shape(self, input_shape):
    return tf.TensorShape(
        self.layer.compute_output_shape(input_shape).as_list())
================================================
FILE: tensor2tensor/layers/common_layers_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for common layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import kfac
import numpy as np
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf
# Eager execution is enabled globally as soon as this test module is
# imported; individual tests are presumably switched to graph mode by the
# test_utils decorators applied to them.
tf.enable_eager_execution()
class CommonLayersTest(parameterized.TestCase, tf.test.TestCase):
@test_utils.run_in_graph_and_eager_modes()
def testIndexLastDimWithIndices(self):
  # Row 0 picks column 2 (4.), row 1 picks column 0 (6.).
  values = np.array([[2., 3., 4., 5.],
                     [6., 7., 8., 9.]])
  picked = common_layers.index_last_dim_with_indices(values, np.array([2, 0]))
  self.assertAllEqual(np.array([4., 6.]), self.evaluate(picked))
@test_utils.run_in_graph_and_eager_modes()
def testSaturatingSigmoid(self):
  inputs = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float32)
  activated = common_layers.saturating_sigmoid(tf.constant(inputs))
  self.assertAllClose(self.evaluate(activated), [0.0, 0.0, 0.5, 1.0, 1.0])
@test_utils.run_in_graph_and_eager_modes()
def testFlatten4D3D(self):
  ids = np.random.randint(1, high=9, size=(3, 5, 2))
  embedded = common_layers.embedding(ids, 10, 7)
  flat = common_layers.flatten4d3d(embedded)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(flat).shape, (3, 5 * 2, 7))
@test_utils.run_in_graph_and_eager_modes()
def testEmbedding(self):
  ids = np.random.randint(1, high=9, size=(3, 5))
  embedded = common_layers.embedding(ids, 10, 16)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(embedded).shape, (3, 5, 16))
@test_utils.run_in_graph_mode_only()
def testShakeShake(self):
  # Shake-shake over identical branches must reproduce the branch itself.
  raw = np.random.rand(5, 7)
  with self.test_session() as sess:
    branch = tf.constant(raw, dtype=tf.float32)
    combined = common_layers.shakeshake([branch] * 5)
    expected, actual = sess.run([branch, combined])
    self.assertAllClose(actual, expected)
@test_utils.run_in_graph_and_eager_modes()
def testConv(self):
  inputs = tf.constant(np.random.rand(5, 7, 1, 11), dtype=tf.float32)
  outputs = common_layers.conv(inputs, 13, (3, 1))
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 5, 1, 13))
@test_utils.run_in_graph_and_eager_modes()
def testConv1d(self):
  inputs = tf.constant(np.random.rand(5, 7, 11), dtype=tf.float32)
  outputs = common_layers.conv1d(inputs, 13, 1)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 7, 13))
@test_utils.run_in_graph_and_eager_modes()
def testSeparableConv(self):
  inputs = tf.constant(np.random.rand(5, 7, 1, 11), dtype=tf.float32)
  outputs = common_layers.separable_conv(inputs, 13, (3, 1))
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 5, 1, 13))
@test_utils.run_in_graph_and_eager_modes()
def testSubSeparableConv(self):
  for separability in [0, 1, 2, 4]:
    data = np.random.rand(5, 7, 1, 12)
    # A separate scope per separability keeps the variables distinct.
    with tf.variable_scope("sep_%d" % separability):
      outputs = common_layers.subseparable_conv(
          tf.constant(data, dtype=tf.float32), 16, (3, 1),
          separability=separability)
      self.evaluate(tf.global_variables_initializer())
      self.assertEqual(self.evaluate(outputs).shape, (5, 5, 1, 16))
@test_utils.run_in_graph_and_eager_modes()
def testConvBlock(self):
  inputs = tf.constant(np.random.rand(5, 7, 1, 11), dtype=tf.float32)
  dilations_and_kernels = [(1, (3, 3)), (1, (3, 3))]
  outputs = common_layers.conv_block(
      inputs, 13, dilations_and_kernels,
      padding="SAME", normalizer_fn=common_layers.noam_norm)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 7, 1, 13))
@test_utils.run_in_graph_and_eager_modes()
def testSeparableConvBlock(self):
  inputs = tf.constant(np.random.rand(5, 7, 1, 11), dtype=tf.float32)
  dilations_and_kernels = [(1, (3, 3)), (1, (3, 3))]
  outputs = common_layers.separable_conv_block(
      inputs, 13, dilations_and_kernels, padding="SAME")
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 7, 1, 13))
@test_utils.run_in_graph_and_eager_modes()
def testSubSeparableConvBlock(self):
  dilations_and_kernels = [(1, (3, 3)), (1, (3, 3))]
  for separability in [0, 1, 2, 4]:
    data = np.random.rand(5, 7, 1, 12)
    with tf.variable_scope("sep_%d" % separability):
      outputs = common_layers.subseparable_conv_block(
          tf.constant(data, dtype=tf.float32),
          16, dilations_and_kernels,
          padding="SAME",
          separability=separability)
      self.evaluate(tf.global_variables_initializer())
      self.assertEqual(self.evaluate(outputs).shape, (5, 7, 1, 16))
@test_utils.run_in_graph_and_eager_modes()
def testPool(self):
  inputs = tf.constant(np.random.rand(5, 8, 1, 11), dtype=tf.float32)
  pooled = common_layers.pool(inputs, (2, 2), "AVG", "SAME")
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(pooled).shape, (5, 8, 1, 11))
@test_utils.run_in_graph_and_eager_modes()
def testConvBlockDownsample(self):
  inputs = tf.constant(np.random.rand(5, 7, 1, 11), dtype=tf.float32)
  outputs = common_layers.conv_block_downsample(inputs, (3, 1), (2, 1), "SAME")
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 4, 1, 27))
@test_utils.run_in_graph_and_eager_modes()
def testGetTimingSignal(self):
  seq_len, timescales = 7, 10
  signal = common_layers.get_timing_signal(seq_len, num_timescales=timescales)
  self.assertEqual(self.evaluate(signal).shape, (seq_len, 2 * timescales))
@test_utils.run_in_graph_and_eager_modes()
def testAddTimingSignal(self):
  expected_shape = (5, 7, 3, 35)  # (batch, length, height, depth)
  inputs = np.random.rand(*expected_shape)
  with_signal = common_layers.add_timing_signal(
      tf.constant(inputs, dtype=tf.float32))
  self.assertEqual(self.evaluate(with_signal).shape, expected_shape)
@test_utils.run_in_graph_and_eager_modes()
def testConvGRU(self):
  data = np.random.rand(5, 7, 3, 11)
  default_pad = common_layers.conv_gru(
      tf.constant(data, dtype=tf.float32), (1, 3), 11)
  left_pad = common_layers.conv_gru(
      tf.constant(data, dtype=tf.float32), (1, 3), 11, padding="LEFT")
  self.evaluate(tf.global_variables_initializer())
  for result in (self.evaluate(default_pad), self.evaluate(left_pad)):
    self.assertEqual(result.shape, (5, 7, 3, 11))
@test_utils.run_in_graph_mode_only
def testSRU(self):
  data = np.random.rand(5, 7, 3, 11)
  with self.test_session() as sess:
    outputs = common_layers.sru(tf.constant(data, dtype=tf.float32))
    sess.run(tf.global_variables_initializer())
    result = sess.run(outputs)
  self.assertEqual(result.shape, (5, 7, 3, 11))
@test_utils.run_in_graph_and_eager_modes()
def testLayerNorm(self):
  inputs = tf.constant(np.random.rand(5, 7, 11), dtype=tf.float32)
  normed = common_layers.layer_norm(inputs, 11)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(normed).shape, (5, 7, 11))
@test_utils.run_in_graph_and_eager_modes()
def testGroupNorm(self):
  inputs = tf.constant(np.random.rand(5, 7, 3, 16), dtype=tf.float32)
  normed = common_layers.group_norm(inputs)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(normed).shape, (5, 7, 3, 16))
@test_utils.run_in_graph_and_eager_modes()
def testConvLSTM(self):
  inputs = tf.constant(np.random.rand(5, 7, 11, 13), dtype=tf.float32)
  outputs = common_layers.conv_lstm(inputs, (1, 3), 13)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 7, 11, 13))
@test_utils.run_in_graph_and_eager_modes()
def testPadToSameLength(self):
  short = tf.constant(np.random.rand(5, 7, 11), dtype=tf.float32)
  longer = tf.constant(np.random.rand(5, 9, 11), dtype=tf.float32)
  # Without a divisor both pad to the longer length (9); with divisor 4 they
  # pad up to the next multiple of 4 (12).
  plain = common_layers.pad_to_same_length(short, longer)
  div4 = common_layers.pad_to_same_length(
      short, longer, final_length_divisible_by=4)
  for padded in self.evaluate(list(plain)):
    self.assertEqual(padded.shape, (5, 9, 11))
  for padded in self.evaluate(list(div4)):
    self.assertEqual(padded.shape, (5, 12, 11))
@test_utils.run_in_graph_and_eager_modes()
def testShiftLeft(self):
  # Despite its name, this test exercises common_layers.shift_right: ones at
  # index 0 of axis 1 must move to index 1.
  inputs = np.zeros((5, 7, 1, 11))
  inputs[:, 0, :] = 1.0
  expected = np.zeros((5, 7, 1, 11))
  expected[:, 1, :] = 1.0
  shifted = common_layers.shift_right(tf.constant(inputs, dtype=tf.float32))
  self.assertAllEqual(self.evaluate(shifted), expected)
@test_utils.run_in_graph_and_eager_modes()
def testConvStride2MultiStep(self):
  inputs = tf.constant(np.random.rand(5, 32, 16, 11), dtype=tf.float32)
  outputs = common_layers.conv_stride2_multistep(inputs, 4, 16)
  self.evaluate(tf.global_variables_initializer())
  # Only the first element of the returned tuple is checked.
  self.assertEqual(self.evaluate(outputs[0]).shape, (5, 2, 1, 16))
@test_utils.run_in_graph_and_eager_modes()
def testDeconvStride2MultiStep(self):
  inputs = tf.constant(np.random.rand(5, 2, 1, 11), dtype=tf.float32)
  outputs = common_layers.deconv_stride2_multistep(inputs, 4, 16)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 32, 1, 16))
@test_utils.run_in_graph_and_eager_modes()
def testApplyNormLayer(self):
  inputs = tf.constant(np.random.rand(5, 2, 1, 11), dtype=tf.float32)
  normed = common_layers.apply_norm(inputs, "layer", depth=11, epsilon=1e-6)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(normed).shape, (5, 2, 1, 11))
@test_utils.run_in_graph_and_eager_modes()
def testApplyNormNoam(self):
  inputs = tf.constant(np.random.rand(5, 2, 1, 11), dtype=tf.float32)
  normed = common_layers.apply_norm(inputs, "noam", depth=11, epsilon=1e-6)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(normed).shape, (5, 2, 1, 11))
@test_utils.run_in_graph_and_eager_modes()
def testApplyNormBatch(self):
  inputs = tf.constant(np.random.rand(5, 2, 1, 11), dtype=tf.float32)
  normed = common_layers.apply_norm(inputs, "batch", depth=11, epsilon=1e-6)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(normed).shape, (5, 2, 1, 11))
@test_utils.run_in_graph_and_eager_modes()
def testApplyNormNone(self):
  raw = np.random.rand(5, 2, 1, 11)
  normed = common_layers.apply_norm(
      tf.constant(raw, dtype=tf.float32), "none", depth=11, epsilon=1e-6)
  self.evaluate(tf.global_variables_initializer())
  result = self.evaluate(normed)
  self.assertEqual(result.shape, (5, 2, 1, 11))
  # "none" must leave the input unchanged (up to float32 rounding).
  self.assertAllClose(result, raw, atol=1e-03)
@test_utils.run_in_graph_mode_only()
def testDenseWithLayerCollection(self):
  with tf.variable_scope("test_layer_collection"):
    inputs_2d = tf.zeros([3, 4], tf.float32)
    collection = kfac.LayerCollection()
    common_layers.dense(
        inputs_2d, units=10, layer_collection=collection, name="y1")
    self.assertLen(collection.get_blocks(), 1)

    # A 3-D input must register one more block.
    inputs_3d = tf.zeros([3, 4, 5], tf.float32)
    common_layers.dense(
        inputs_3d, units=10, layer_collection=collection, name="y2")
    self.assertLen(collection.get_blocks(), 2)
def testGlobalPool1d(self):
  inputs = tf.Variable(np.random.rand(5, 4, 11), dtype=tf.float32)
  ones_mask = tf.Variable(np.ones((5, 4)), dtype=tf.float32)
  zeros_mask = tf.Variable(np.zeros((5, 4)), dtype=tf.float32)

  # An all-ones mask must match no mask; an all-zeros mask zeroes max-pool.
  max_diff = tf.reduce_sum(
      common_layers.global_pool_1d(inputs) -
      common_layers.global_pool_1d(inputs, mask=ones_mask))
  masked_max_sum = tf.reduce_sum(
      common_layers.global_pool_1d(inputs, mask=zeros_mask))
  avr_diff = tf.reduce_sum(
      common_layers.global_pool_1d(inputs, "AVR") -
      common_layers.global_pool_1d(inputs, "AVR", ones_mask))
  masked_avr_sum = tf.reduce_sum(
      common_layers.global_pool_1d(inputs, "AVR", zeros_mask))

  self.evaluate(tf.global_variables_initializer())
  results = self.evaluate([max_diff, masked_max_sum, avr_diff, masked_avr_sum])
  # The fully-masked average is evaluated but deliberately not asserted.
  self.assertAllEqual(results[:3], [0.0, 0.0, 0.0])
def testLinearSetLayer(self):
  elements = tf.Variable(np.random.rand(5, 4, 11), dtype=tf.float32)
  context = tf.Variable(np.random.rand(5, 13), dtype=tf.float32)
  without_ctx = common_layers.linear_set_layer(32, elements)
  with_ctx = common_layers.linear_set_layer(32, elements, context=context)
  self.evaluate(tf.global_variables_initializer())
  for result in self.evaluate([without_ctx, with_ctx]):
    self.assertEqual(result.shape, (5, 4, 32))
def testRavanbakhshSetLayer(self):
  elements = tf.Variable(np.random.rand(5, 4, 11), dtype=tf.float32)
  outputs = common_layers.ravanbakhsh_set_layer(32, elements)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(outputs).shape, (5, 4, 32))
@test_utils.run_in_graph_and_eager_modes()
def testTopKthIterativeShape(self):
  inputs = tf.constant(np.random.rand(5, 2, 1, 12), dtype=tf.float32)
  kth = common_layers.top_kth_iterative(inputs, 3)
  # The last axis collapses to size 1.
  self.assertEqual(self.evaluate(kth).shape, (5, 2, 1, 1))
@test_utils.run_in_graph_and_eager_modes()
def testTopKthIterativeValue(self):
  """The 3rd largest entry of [1, 2, 3, 4] is 2."""
  x = [1.0, 2.0, 3.0, 4.0]
  y = common_layers.top_kth_iterative(tf.constant(x, dtype=tf.float32), 3)
  actual = self.evaluate(y)
  # Compare the float itself: int() truncates, so it would also accept
  # any wrong value in [2, 3).
  self.assertAllClose(actual[0], 2.0)
@test_utils.run_in_graph_and_eager_modes()
def testBReLU(self):
  """brelu preserves the input shape."""
  shape = (5, 2, 1, 12)
  activated = common_layers.brelu(
      tf.constant(np.random.rand(*shape), dtype=tf.float32))
  self.assertEqual(self.evaluate(activated).shape, shape)
@test_utils.run_in_graph_and_eager_modes()
def testBELU(self):
  """belu preserves the input shape."""
  shape = (5, 2, 1, 12)
  activated = common_layers.belu(
      tf.constant(np.random.rand(*shape), dtype=tf.float32))
  self.assertEqual(self.evaluate(activated).shape, shape)
@test_utils.run_in_graph_and_eager_modes()
def testNAC(self):
  """nac projects the last axis from 12 to 14 units."""
  inputs = tf.constant(np.random.rand(5, 2, 1, 12), dtype=tf.float32)
  output = common_layers.nac(inputs, 14)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(output).shape, (5, 2, 1, 14))
@test_utils.run_in_graph_and_eager_modes()
def testNALU(self):
  """nalu projects the last axis from 12 to 14 units."""
  inputs = tf.constant(np.random.rand(5, 2, 1, 12), dtype=tf.float32)
  output = common_layers.nalu(inputs, 14)
  self.evaluate(tf.global_variables_initializer())
  self.assertEqual(self.evaluate(output).shape, (5, 2, 1, 14))
@test_utils.run_in_graph_and_eager_modes()
def testNALUzeros(self):
  """nalu on an all-zero input yields finite values of the expected shape."""
  # Only the shape of `template` is used; the draw still advances NumPy's RNG.
  template = np.random.rand(5, 2, 1, 12)
  output = common_layers.nalu(tf.zeros_like(template, dtype=tf.float32), 14)
  self.evaluate(tf.global_variables_initializer())
  result = self.evaluate(output)
  self.assertTrue(np.all(np.isfinite(result)))
  self.assertEqual(result.shape, (5, 2, 1, 14))
@test_utils.run_in_graph_mode_only
def testPaddingCrossEntropyFactored(self):
  """padded_cross_entropy_factored matches the dense padded_cross_entropy."""
  vocab_size = 19
  rows = 5
  cols = 4
  depth = 11
  label_smoothing = 0.1
  features = np.random.rand(rows, cols, depth)
  weights = np.random.rand(vocab_size, depth)
  # np.random.randint excludes its upper bound; use vocab_size so the last
  # vocabulary id can be drawn as a label too.
  labels = np.random.randint(0, vocab_size, size=(rows, cols))
  with self.test_session() as session:
    features = tf.to_float(features)
    weights = tf.to_float(weights)
    labels = tf.to_int32(labels)
    # Dense reference: materialize the full [rows, cols, vocab_size] logits.
    logits = tf.matmul(
        tf.reshape(features, [rows * cols, depth]), weights, transpose_b=True)
    logits = tf.reshape(logits, [rows, cols, vocab_size])
    loss_num, loss_den = common_layers.padded_cross_entropy(
        logits, labels, label_smoothing=label_smoothing, reduce_sum=False)
    factored_logits = common_layers.FactoredTensor(features, weights)
    loss_num_f, loss_den_f = common_layers.padded_cross_entropy_factored(
        factored_logits,
        labels=labels,
        label_smoothing=label_smoothing,
        reduce_sum=False)
    num, den, num_f, den_f = session.run(
        [loss_num, loss_den, loss_num_f, loss_den_f])
    self.assertEqual(num.shape, (rows, cols))
    self.assertEqual(den.shape, (rows, cols))
    self.assertEqual(num_f.shape, (rows, cols))
    self.assertEqual(den_f.shape, (rows, cols))
    self.assertAllClose(num, num_f)
    self.assertAllClose(den, den_f)
@test_utils.run_in_graph_mode_only
def testPaddingCrossEntropyFactoredGrad(self):
  """Gradients of the factored and dense padded cross entropies agree."""
  vocab_size = 19
  rows = 5
  cols = 4
  depth = 11
  label_smoothing = 0.1
  features = np.random.rand(rows, cols, depth)
  weights = np.random.rand(vocab_size, depth)
  # np.random.randint excludes its upper bound; use vocab_size so the last
  # vocabulary id can be drawn as a label too.
  labels = np.random.randint(0, vocab_size, size=(rows, cols))
  with self.test_session() as session:
    features = tf.to_float(features)
    weights = tf.to_float(weights)
    labels = tf.to_int32(labels)
    logits = tf.matmul(
        tf.reshape(features, [rows * cols, depth]), weights, transpose_b=True)
    logits = tf.reshape(logits, [rows, cols, vocab_size])
    loss_num, loss_den = common_layers.padded_cross_entropy(
        logits, labels, label_smoothing=label_smoothing, reduce_sum=False)
    factored_logits = common_layers.FactoredTensor(features, weights)
    loss_num_factored, loss_den_factored = (
        common_layers.padded_cross_entropy_factored(
            factored_logits,
            labels=labels,
            label_smoothing=label_smoothing,
            reduce_sum=False))
    df, dw = tf.gradients(ys=[loss_num, loss_den], xs=[features, weights])
    df_factored, dw_factored = tf.gradients(
        ys=[loss_num_factored, loss_den_factored], xs=[features, weights])
    actual_df, actual_dw, actual_df_factored, actual_dw_factored = (
        session.run([df, dw, df_factored, dw_factored]))
    self.assertEqual(actual_df.shape, (rows, cols, depth))
    self.assertEqual(actual_dw.shape, (vocab_size, depth))
    self.assertEqual(actual_df_factored.shape, (rows, cols, depth))
    self.assertEqual(actual_dw_factored.shape, (vocab_size, depth))
    self.assertAllClose(actual_df, actual_df_factored)
    self.assertAllClose(actual_dw, actual_dw_factored)
@parameterized.parameters(
    (2, 4, 4, 5, True),
    (2, 4, 4, 5, False),
    (1, 16, 16, 1, True),
    (1, 16, 16, 1, False),
)
def testDmlLoss(self, batch, height, width, num_mixtures, reduce_sum):
  """dml_loss equals the per-channel discretized mix-logistic loss."""
  channels = 3
  spatial = [batch, height, width]
  pred = tf.random_normal(spatial + [num_mixtures * 10])
  labels = tf.random_uniform(spatial + [channels],
                             minval=0, maxval=256, dtype=tf.int32)
  loss_num, loss_den = common_layers.dml_loss(
      pred=pred, labels=labels, reduce_sum=reduce_sum)
  actual_loss = loss_num / loss_den
  symmetric_labels = common_layers.convert_rgb_to_symmetric_real(labels)
  expected_loss = common_layers.discretized_mix_logistic_loss(
      pred=pred, labels=symmetric_labels) / channels
  if reduce_sum:
    expected_loss = tf.reduce_mean(expected_loss)
  actual_val, expected_val = self.evaluate([actual_loss, expected_loss])
  self.assertAllClose(actual_val, expected_val)
@test_utils.run_in_graph_and_eager_modes()
def testWeightsMultiProblemAll(self):
  """Rows containing the task id get weight 1 everywhere; other rows get 0."""
  label_rows = [[12, 15, 1, 20, 100],
                [67, 1, 34, 45, 124],
                [78, 2, 34, 18, 29],
                [78, 123, 55, 1, 33],
                [1, 18, 22, 36, 59]]
  labels = tf.constant(np.array(label_rows), dtype=tf.int32)
  taskid = 1
  mask = common_layers.weights_multi_problem_all(labels, taskid)
  expected_mask = np.array([[1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1],
                            [0, 0, 0, 0, 0],
                            [1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1]])
  self.assertAllClose(expected_mask, self.evaluate(mask))
@test_utils.run_in_graph_and_eager_modes()
def testWeightsMultiProblem(self):
  """Positions after the task-id token get weight 1; rows without it get 0."""
  label_rows = [[12, 15, 1, 20, 100],
                [67, 1, 34, 45, 124],
                [78, 2, 34, 18, 29],
                [78, 123, 55, 1, 33],
                [1, 18, 22, 36, 59]]
  labels = tf.constant(np.array(label_rows), dtype=tf.int32)
  taskid = 1
  mask = common_layers.weights_multi_problem(labels, taskid)
  expected_mask = np.array([[0, 0, 0, 1, 1],
                            [0, 0, 1, 1, 1],
                            [0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 1],
                            [0, 1, 1, 1, 1]])
  self.assertAllClose(expected_mask, self.evaluate(mask))
@test_utils.run_in_graph_and_eager_modes()
def testDiscretizedMixLogisticLoss(self):
  """With one dominant component the loss is that logistic's per-bin NLL."""
  batch, height, width = 2, 4, 4
  channels = 3
  num_mixtures = 5
  spatial = [batch, height, width]
  # Put (effectively) all of the mixture probability on component 0.
  logits = tf.concat(
      [tf.ones(spatial + [1]) * 1e8,
       tf.zeros(spatial + [num_mixtures - 1])],
      axis=-1)
  locs = tf.random_uniform(spatial + [num_mixtures * 3],
                           minval=-.9, maxval=.9)
  log_scales = tf.random_uniform(spatial + [num_mixtures * 3],
                                 minval=-1., maxval=1.)
  coeffs = tf.atanh(tf.zeros(spatial + [num_mixtures * 3]))
  pred = tf.concat([logits, locs, log_scales, coeffs], axis=-1)
  # Keep labels inside (-.9, .9) so the 8-bit edge cases (0 and 255) never
  # occur.
  labels = tf.random_uniform(spatial + [channels], minval=-.9, maxval=.9)
  # Reference: mass of component 0's logistic over a bin of half-width 1/255
  # centred on each label.
  centered = labels - locs[..., :3]
  inv_stdv = tf.exp(-log_scales[..., :3])
  cdf_plus = tf.nn.sigmoid(inv_stdv * (centered + 1. / 255.))
  cdf_min = tf.nn.sigmoid(inv_stdv * (centered - 1. / 255.))
  expected_loss = -tf.reduce_sum(tf.log(cdf_plus - cdf_min), axis=-1)
  actual_loss = common_layers.discretized_mix_logistic_loss(
      pred=pred, labels=labels)
  actual_val, expected_val = self.evaluate([actual_loss, expected_loss])
  self.assertAllClose(actual_val, expected_val, rtol=1e-5)
@test_utils.run_in_graph_and_eager_modes()
def testSampleFromDiscretizedMixLogistic(self):
  """With one dominant component and tiny scales, samples land on its locs."""
  batch, height, width = 2, 4, 4
  num_mixtures = 5
  seed = 42
  spatial = [batch, height, width]
  # Component 0 receives (effectively) all of the mixture probability.
  logits = tf.concat(
      [tf.ones(spatial + [1]) * 1e8,
       tf.zeros(spatial + [num_mixtures - 1])],
      axis=-1)
  locs = tf.random_uniform(spatial + [num_mixtures * 3],
                           minval=-.9, maxval=.9)
  log_scales = tf.ones(spatial + [num_mixtures * 3]) * -1e8
  coeffs = tf.atanh(tf.zeros(spatial + [num_mixtures * 3]))
  pred = tf.concat([logits, locs, log_scales, coeffs], axis=-1)
  expected_sample = tf.clip_by_value(locs[..., :3], -1., 1.)
  actual_sample = common_layers.sample_from_discretized_mix_logistic(
      pred, seed=seed)
  actual_val, expected_val = self.evaluate([actual_sample, expected_sample])
  # Loose tolerance: the implementation clips log-scales, so they still add
  # some noise to every sample.
  self.assertAllClose(actual_val, expected_val, atol=1e-2)
@test_utils.run_in_graph_and_eager_modes()
def testFactoredTensorImplicitConversion(self):
  """Adding a FactoredTensor to a tensor converts it to its (3, 4, 6) product."""
  left = np.random.rand(3, 4, 5)
  right = np.random.rand(6, 5)
  addend = np.random.rand(3, 4, 6)
  factored = common_layers.FactoredTensor(tf.to_float(left),
                                          tf.to_float(right))
  # `+` triggers the implicit conversion, which performs the matmul.
  total = factored + tf.to_float(addend)
  self.assertEqual(self.evaluate(total).shape, (3, 4, 6))
@test_utils.run_in_graph_mode_only()
def testConvHiddenReluMemoryEfficient(self):
  """forget=True and forget=False give the same outputs and gradients."""
  batch = 3
  length = 23
  io_size = 16
  filter_size = 7
  inputs_np = np.random.rand(batch, length, io_size)
  upstream_np = np.random.rand(batch, length, io_size)
  with self.test_session() as session:
    x = tf.to_float(inputs_np)
    dy = tf.to_float(upstream_np)
    f1 = tf.get_variable("f1", [1, io_size, filter_size])
    f2 = tf.get_variable("f2", [1, filter_size, io_size])
    norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    shared_vars = (f1, f2, norm_scale, norm_bias)
    wrt = [x, f1, f2, norm_scale, norm_bias]
    y = common_layers.conv_hidden_relu_memory_efficient(
        x, filter_size, forget=False, test_vars=shared_vars)
    y_forget = common_layers.conv_hidden_relu_memory_efficient(
        x, filter_size, forget=True, test_vars=shared_vars)
    grads = tf.gradients(ys=[y], xs=wrt, grad_ys=[dy])
    grads_forget = tf.gradients(ys=[y_forget], xs=wrt, grad_ys=[dy])
    session.run(tf.global_variables_initializer())
    y_val, y_forget_val, grad_vals, grad_forget_vals = session.run(
        [y, y_forget, grads, grads_forget])
    self.assertAllClose(y_val, y_forget_val)
    # Filter gradients (f1, f2) get a slightly looser tolerance.
    loose = dict(rtol=2e-6, atol=2e-6)
    tolerances = [{}, loose, loose, {}, {}]
    for grad, grad_forget, tol in zip(grad_vals, grad_forget_vals,
                                      tolerances):
      self.assertAllClose(grad, grad_forget, **tol)
@test_utils.run_in_graph_and_eager_modes()
def testTopk(self):
  """_select_top_k keeps k logits per example; k == -1 keeps all of them."""
  batch_size, seq_len, vocab_size = 3, 5, 7
  top_k = [3, 2, -1]
  logits = np.random.rand(batch_size, seq_len, 1, 1, vocab_size) + 0.001
  selected = common_layers._select_top_k(logits, top_k)  # pylint: disable=protected-access
  self.evaluate(tf.global_variables_initializer())
  selected = self.evaluate(selected)
  for row, k in enumerate(top_k):
    expected = vocab_size if k == -1 else k
    for position in range(seq_len):
      kept = (selected[row, position, 0, 0, :] > -1e6).sum()
      self.assertEqual(kept, expected)
@test_utils.run_in_graph_and_eager_modes()
def testSampleTemperaturePerExample(self):
  """Sampling drops the vocab axis: [batch, seq, 1, 1, vocab] -> [batch, seq, 1, 1]."""
  batch_size, seq_len, vocab_size = 3, 5, 7
  logits = np.random.randn(batch_size, seq_len, 1, 1, vocab_size)
  temperature = np.random.rand(batch_size)
  samples = common_layers.sample_temperature_per_example(
      logits, temperature, -1)
  sample_shape = self.evaluate(tf.shape(samples))
  self.assertAllEqual(sample_shape, [batch_size, seq_len, 1, 1])
@test_utils.run_in_graph_and_eager_modes()
def testSampleTemperaturePerExampleWithTopK(self):
  """A per-example top_k vector leaves the 4-D sample shape unchanged."""
  batch_size, seq_len, vocab_size = 3, 5, 7
  logits = np.random.randn(batch_size, seq_len, 1, 1, vocab_size)
  temperature = np.random.rand(batch_size)
  per_example_top_k = np.array([3, -1, 4], dtype=np.int32)
  samples = common_layers.sample_temperature_per_example(
      logits, temperature, per_example_top_k)
  sample_shape = self.evaluate(tf.shape(samples))
  self.assertAllEqual(sample_shape, [batch_size, seq_len, 1, 1])
@test_utils.run_in_graph_and_eager_modes()
def testSampleTemperaturePerExampleWithTopK2(self):
  """2-D [batch, vocab] logits yield one sample per example."""
  batch_size, vocab_size = 3, 7
  logits = np.random.randn(batch_size, vocab_size)
  temperature = np.random.rand(batch_size)
  per_example_top_k = np.array([3, -1, 4], dtype=np.int32)
  samples = common_layers.sample_temperature_per_example(
      logits, temperature, per_example_top_k)
  self.assertAllEqual(self.evaluate(tf.shape(samples)), [batch_size])
@test_utils.run_in_graph_mode_only()
def testSampleTemperaturePerExampleDynamicBatchSize(self):
  """An unknown (None) batch dimension propagates to the static output shape."""
  unknown_batch = None
  vocab_size = 7
  logits = tf.placeholder(tf.float32, shape=(unknown_batch, vocab_size))
  temperature = tf.placeholder(tf.float32, shape=(unknown_batch, 1))
  keep_top_k = tf.placeholder(tf.int32, shape=(unknown_batch, 1))
  samples = common_layers.sample_temperature_per_example(
      logits, temperature, keep_top_k)
  self.assertAllEqual(samples.shape.as_list(), [unknown_batch])
@test_utils.run_in_graph_and_eager_modes()
def testCycleGANUpsampleNnUpsampleConv(self):
  """nn_upsample_conv scales height and width exactly by the strides."""
  batch, height, width = 8, 32, 32
  num_channels = 3
  output_filters = 10
  stride = [2, 3]  # height grows x2, width grows x3
  images = np.random.rand(batch, height, width, num_channels).astype(
      np.float32)
  upsampled = common_layers.cyclegan_upsample(
      images, output_filters, stride, "nn_upsample_conv")
  upsampled_shape = tf.shape(upsampled)
  self.evaluate(tf.global_variables_initializer())
  expected_shape = [batch, height * stride[0], width * stride[1],
                    output_filters]
  self.assertAllEqual(expected_shape, self.evaluate(upsampled_shape))
@test_utils.run_in_graph_and_eager_modes()
def testCycleGANUpsampleBilinearUpsampleConv(self):
  """bilinear_upsample_conv scales height and width exactly by the strides."""
  batch, height, width = 8, 32, 32
  num_channels = 3
  output_filters = 10
  stride = [2, 3]  # height grows x2, width grows x3
  images = np.random.rand(batch, height, width, num_channels).astype(
      np.float32)
  upsampled = common_layers.cyclegan_upsample(
      images, output_filters, stride, "bilinear_upsample_conv")
  upsampled_shape = tf.shape(upsampled)
  self.evaluate(tf.global_variables_initializer())
  expected_shape = [batch, height * stride[0], width * stride[1],
                    output_filters]
  self.assertAllEqual(expected_shape, self.evaluate(upsampled_shape))
@test_utils.run_in_graph_and_eager_modes()
def testCycleGANUpsampleConv2dTranspose(self):
  """conv2d_transpose output size follows the transposed-conv formula."""
  batch, height, width = 8, 32, 32
  num_channels = 3
  output_filters = 10
  stride = [2, 3]  # height grows ~x2, width grows ~x3
  images = tf.convert_to_tensor(
      np.random.rand(batch, height, width, num_channels), dtype=tf.float32)
  # new = (old - 1) * stride + kernel - 2 * padding - correction,
  # with kernel = 3, padding = 0, correction = 1.
  kernel, padding, correction = 3, 0, 1
  expected_height = (height - 1) * stride[0] + kernel - 2 * padding - correction
  expected_width = (width - 1) * stride[1] + kernel - 2 * padding - correction
  upsampled = common_layers.cyclegan_upsample(
      images, output_filters, stride, "conv2d_transpose")
  upsampled_shape = tf.shape(upsampled)
  self.evaluate(tf.global_variables_initializer())
  self.assertAllEqual(
      [batch, expected_height, expected_width, output_filters],
      self.evaluate(upsampled_shape))
def testSpectralNorm(self):
  """After 20 assign_op updates, the normalized matrix has spectral norm ~1."""
  with tf.Graph().as_default():
    weights = tf.get_variable("w", dtype=tf.float32, shape=[2, 3, 50, 100])
    weights = tf.multiply(weights, 10.0)
    # Build the ops once and re-run `assign_op`. Rebuilding them on every
    # iteration only grew the graph: the rebuilt ops necessarily shared state
    # with the first ones, since variables created after the initializer ran
    # would make the later assign_op runs fail.
    normed_weight, assign_op = common_layers.apply_spectral_norm(weights)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(20):
        sess.run(assign_op)
      normed_weight = sess.run(normed_weight).reshape(-1, 100)
      _, singular_values, _ = np.linalg.svd(normed_weight)
      self.assertAllClose(singular_values[0], 1.0, rtol=0.1)
class FnWithCustomGradTest(tf.test.TestCase):
  """Tests for common_layers.fn_with_custom_grad."""

  @test_utils.run_in_graph_mode_only()
  def testCorrectness(self):
    """A grad_fn that recomputes true gradients matches plain autodiff."""
    w = tf.random_uniform([6, 10])

    def fn(a, b, c):
      dense_out = tf.layers.dense(
          a,
          10,
          use_bias=False,
          kernel_initializer=lambda shape, dtype, partition_info: w)
      return dense_out + tf.matmul(b, c)

    def grad_fn(inputs, variables, outputs, grad_outputs):
      out, grad_out = outputs[0], grad_outputs[0]
      return (tf.gradients(out, inputs, grad_ys=grad_out),
              tf.gradients(out, variables, grad_ys=grad_out))

    custom_fn = common_layers.fn_with_custom_grad(grad_fn)(fn)
    a = tf.random_uniform([11, 6])
    b = tf.random_uniform([11, 7])
    c = tf.random_uniform([7, 10])
    out = fn(a, b, c)
    custom_out = custom_fn(a, b, c)
    self.assertEqual(out.get_shape().as_list(),
                     custom_out.get_shape().as_list())
    loss = tf.reduce_mean(out)
    custom_loss = tf.reduce_mean(custom_out)
    # Each fn call creates its own dense kernel: [0] is fn's, [1] custom_fn's.
    grads = tf.gradients(loss, [a, b, c, tf.trainable_variables()[0]])
    custom_grads = tf.gradients(custom_loss,
                                [a, b, c, tf.trainable_variables()[1]])
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
          [out, custom_out, grads, custom_grads])
      self.assertAllClose(out_val, custom_out_val)
      for expected, actual in zip(grads_val, custom_grads_val):
        self.assertAllClose(expected, actual)

  @test_utils.run_in_graph_mode_only()
  def testCustomGrad(self):
    """The gradients reported are exactly those returned by grad_fn."""

    def fn(a, b, c):
      return tf.layers.dense(a, 10, use_bias=False) + tf.matmul(b, c)

    def grad_fn(inputs, variables, unused_outputs, unused_grad_outputs):
      # Constant gradients 1, 2, 3, ... so each input/variable is identifiable.
      num_inputs = len(inputs)
      grad_inputs = [tf.ones_like(t) * (i + 1.) for i, t in enumerate(inputs)]
      grad_vars = [tf.ones_like(t) * (num_inputs + i + 1.)
                   for i, t in enumerate(variables)]
      return grad_inputs, grad_vars

    a = tf.random_uniform([11, 6])
    b = tf.random_uniform([11, 7])
    c = tf.random_uniform([7, 10])
    # `w` only supplies the dense kernel's shape for the expected gradient.
    w = tf.random_uniform([6, 10])
    out = common_layers.fn_with_custom_grad(grad_fn)(fn)(a, b, c)
    loss = tf.reduce_mean(out)
    grads = tf.gradients(loss, [a, b, c, tf.trainable_variables()[0]])
    expected_grads = [
        tf.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w])
    ]
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      grad_vals, expected_vals = sess.run([grads, expected_grads])
      for actual, expected in zip(grad_vals, expected_vals):
        self.assertAllClose(actual, expected)
class RecomputeTest(tf.test.TestCase):
  """Tests for common_layers.recompute_grad."""

  @test_utils.run_in_graph_mode_only()
  def testRecompute(self):
    """Wrapping with recompute_grad changes neither outputs nor gradients."""

    def layer(x, name=None):
      with tf.variable_scope(name, default_name="layer"):
        normed = common_layers.layer_norm(x)
        projected = tf.layers.conv1d(
            normed,
            10,
            1,
            use_bias=False,
            kernel_initializer=tf.constant_initializer(42.42))
        return tf.nn.relu(projected)

    def fn(x):
      out = x
      for _ in range(3):
        out = layer(out)
      return out

    @common_layers.recompute_grad
    def fn_recompute(x):
      return fn(x)

    x = tf.random_uniform((3, 1, 3))
    with tf.variable_scope("recompute") as scope:
      out_recompute = tf.reduce_sum(fn_recompute(x))
      recompute_vars = scope.trainable_variables()
    with tf.variable_scope("regular") as scope:
      out_regular = tf.reduce_sum(fn(x))
      regular_vars = scope.trainable_variables()
    grads_recompute = tf.gradients(out_recompute, recompute_vars)
    grads_regular = tf.gradients(out_regular, regular_vars)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      (out_recompute_val, out_regular_val,
       grads_recompute_val, grads_regular_val) = sess.run(
           [out_recompute, out_regular, grads_recompute, grads_regular])
      self.assertAllClose(out_recompute_val, out_regular_val)
      for g_recompute, g_regular in zip(grads_recompute_val,
                                        grads_regular_val):
        self.assertAllClose(g_recompute, g_regular)
class WeightNormTest(tf.test.TestCase):
  """Tests for common_layers.WeightNorm."""

  def testInputSpec(self):
    """Test that WeightNorm does not overspecify the input_spec."""
    conv = common_layers.WeightNorm(
        tf.keras.layers.Conv1D(filters=8, kernel_size=3))
    # The first call builds the layer; a different batch size must still be
    # accepted afterwards.
    for batch_size in (1, 2):
      conv(tf.zeros([batch_size, 16, 2]))
    # A mismatched feature dimension, however, is rejected.
    with self.assertRaises(ValueError):
      conv(tf.zeros([2, 16, 3]))
if __name__ == "__main__":
  # Run the tf.test.TestCase classes defined in this module.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/common_video.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for video."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import contrib
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import summary_op_util # pylint: disable=g-direct-tensorflow-import
# After tf-nightly 1.14.1.dev20190314 summary_op_util.skip_summary was extracted
# out to the distribute module.
try:
from tensorflow.python.distribute import summary_op_util as distribute_summary_op_util # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
except ImportError:
distribute_summary_op_util = summary_op_util
# Layer namespace obtained from common_layers.layers(); used below as
# tfl.dense, tfl.conv2d, tfl.flatten and tfl.batch_normalization.
tfl = common_layers.layers()
def swap_time_and_batch_axes(inputs):
  """Transposes `inputs` so that its first two axes trade places.

  The permutation for the remaining axes is built from tf.rank, so any tensor
  of rank >= 2 is accepted.

  Args:
    inputs: tensor whose axes 0 and 1 should be swapped.

  Returns:
    The transposed tensor.
  """
  rank = tf.rank(inputs)
  trailing_axes = tf.range(2, rank)
  permutation = tf.concat([[1, 0], trailing_axes], axis=0)
  return tf.transpose(inputs, permutation)
def encode_to_shape(inputs, shape, scope):
  """Projects `inputs` to a single-channel map of size shape[1] x shape[2].

  Args:
    inputs: tensor to encode; it is flattened before a dense projection.
    shape: target shape; only shape[1] and shape[2] are read.
    scope: variable scope name; variables are reused via tf.AUTO_REUSE.

  Returns:
    Tensor reshaped to (-1, shape[1], shape[2], 1).
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    rows, cols = shape[1], shape[2]
    flat = tfl.flatten(inputs)
    projected = tfl.dense(flat, rows * cols, activation=None, name="enc_dense")
    return tf.reshape(projected, (-1, rows, cols, 1))
def decode_to_shape(inputs, shape, scope):
  """Decodes `inputs` into shape[2] dense features with a length-1 axis 1.

  Args:
    inputs: tensor to decode; it is flattened before a dense projection.
    shape: target shape; only shape[2] (the number of output units) is read.
    scope: variable scope name; variables are reused via tf.AUTO_REUSE.

  Returns:
    The shape[2]-unit dense output with a new axis inserted at position 1.
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    x = inputs
    x = tfl.flatten(x)
    x = tfl.dense(x, shape[2], activation=None, name="dec_dense")
    x = tf.expand_dims(x, axis=1)
    return x
def basic_lstm(inputs, state, num_units, name=None):
  """Runs one step of a BasicLSTMCell whose weights are shared across calls.

  Args:
    inputs: input tensor; its leading dimension is used as the batch size.
    state: previous cell state, or None to start from the zero state.
    num_units: number of LSTM units.
    name: cell name; with reuse=tf.AUTO_REUSE, calls with the same name share
      parameters (e.g. across time-steps).

  Returns:
    (outputs, new_state) produced by the cell.
  """
  batch_size = common_layers.shape_list(inputs)[0]
  cell = tf.nn.rnn_cell.BasicLSTMCell(
      num_units, name=name, reuse=tf.AUTO_REUSE)
  if state is None:
    state = cell.zero_state(batch_size, tf.float32)
  outputs, new_state = cell(inputs, state)
  return outputs, new_state
def lstm_cell(inputs,
              state,
              num_units,
              use_peepholes=False,
              cell_clip=0.0,
              initializer=None,
              num_proj=None,
              num_unit_shards=None,
              num_proj_shards=None,
              reuse=None,
              name=None):
  """Runs one step of a full tf.nn.rnn_cell.LSTMCell.

  The cell is built with state_is_tuple=False, so `state` and the returned
  state are single tensors rather than (c, h) tuples.

  Args:
    inputs: input tensor; its leading dimension is used as the batch size.
    state: previous cell state, or None to start from the zero state.
    num_units: number of LSTM units.
    use_peepholes, cell_clip, initializer, num_proj, num_unit_shards,
    num_proj_shards, reuse, name: forwarded unchanged to LSTMCell.

  Returns:
    (outputs, new_state) produced by the cell.
  """
  batch_size = common_layers.shape_list(inputs)[0]
  cell_options = {
      "use_peepholes": use_peepholes,
      "cell_clip": cell_clip,
      "initializer": initializer,
      "num_proj": num_proj,
      "num_unit_shards": num_unit_shards,
      "num_proj_shards": num_proj_shards,
      "reuse": reuse,
      "name": name,
      "state_is_tuple": False,
  }
  cell = tf.nn.rnn_cell.LSTMCell(num_units, **cell_options)
  if state is None:
    state = cell.zero_state(batch_size, tf.float32)
  outputs, new_state = cell(inputs, state)
  return outputs, new_state
def conv_lstm_2d(inputs, state, output_channels,
                 kernel_size=5, name=None, spatial_dims=None):
  """Runs one step of a 2D convolutional LSTM.

  Args:
    inputs: input tensor; axis 0 is the batch and the last axis the channels.
    state: previous cell state, or None to start from the zero state.
    output_channels: number of output channels of the cell.
    kernel_size: side length of the square convolution kernel.
    name: optional cell name.
    spatial_dims: optional spatial dimensions to give the cell instead of those
      read from `inputs`; any sequence (list or tuple) is accepted.

  Returns:
    (outputs, new_state) produced by the cell.
  """
  input_shape = common_layers.shape_list(inputs)
  batch_size, input_channels = input_shape[0], input_shape[-1]
  if spatial_dims is None:
    cell_input_shape = input_shape[1:]
  else:
    # list() so that tuples work too: `tuple + list` raises TypeError.
    cell_input_shape = list(spatial_dims) + [input_channels]
  cell = contrib.rnn().ConvLSTMCell(
      2, cell_input_shape, output_channels, [kernel_size, kernel_size],
      name=name)
  if state is None:
    state = cell.zero_state(batch_size, tf.float32)
  outputs, new_state = cell(inputs, state)
  return outputs, new_state
def scheduled_sample_count(ground_truth_x,
                           generated_x,
                           batch_size,
                           scheduled_sample_var):
  """Mixes a batch from a fixed number of ground-truth and generated examples.

  A random permutation of the batch positions is split in two: the first
  `scheduled_sample_var` positions take ground-truth examples, the remaining
  positions take generated ones.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size.
    scheduled_sample_var: number of ground-truth examples to include in batch.

  Returns:
    New batch with scheduled_sample_var examples from ground_truth_x and the
    rest from generated_x.
  """
  shuffled = tf.random_shuffle(tf.range(batch_size))
  gt_positions = tf.gather(shuffled, tf.range(scheduled_sample_var))
  gen_positions = tf.gather(shuffled,
                            tf.range(scheduled_sample_var, batch_size))
  mixed = tf.dynamic_stitch(
      [gt_positions, gen_positions],
      [tf.gather(ground_truth_x, gt_positions),
       tf.gather(generated_x, gen_positions)])
  # dynamic_stitch loses the static batch size; restore it when it is known.
  if isinstance(batch_size, int):
    mixed.set_shape([batch_size] + common_layers.shape_list(mixed)[1:])
  return mixed
def inject_additional_input(layer, inputs, name, mode="concat"):
  """Injects the additional input into the layer.

  Args:
    layer: layer that the input should be injected to.
    inputs: inputs to be injected.
    name: TF scope name.
    mode: how the info should be added to the layer:
      "concat" encodes inputs with encode_to_shape into a one-channel map of
        the layer's spatial size and concats it as an additional channel.
      "multiplicative" projects inputs to one value per layer channel and
        multiplies the layer by it, broadcast over the spatial axes.
      "multi_additive" projects inputs twice: the layer is multiplied by the
        sigmoid of the first projection, then the second one is added.

  Returns:
    updated layer.

  Raises:
    ValueError: in case of unknown mode.
  """
  layer_shape = common_layers.shape_list(layer)
  input_shape = common_layers.shape_list(inputs)
  if mode == "concat":
    emb = encode_to_shape(inputs, layer_shape, name)
    layer = tf.concat(values=[layer, emb], axis=-1)
  elif mode == "multiplicative":
    # Only this mode needs the zeros tensor; create it here instead of
    # allocating a full layer-sized tensor for every mode.
    zeros_mask = tf.zeros(layer_shape, dtype=tf.float32)
    filters = layer_shape[-1]
    input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
    input_mask = tf.layers.dense(input_reshaped, filters, name=name)
    # Adding zeros materializes the broadcast to the layer's full shape.
    input_broad = input_mask + zeros_mask
    layer *= input_broad
  elif mode == "multi_additive":
    filters = layer_shape[-1]
    input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
    input_mul = tf.layers.dense(input_reshaped, filters, name=name + "_mul")
    layer *= tf.nn.sigmoid(input_mul)
    input_add = tf.layers.dense(input_reshaped, filters, name=name + "_add")
    layer += input_add
  else:
    raise ValueError("Unknown injection mode: %s" % mode)
  return layer
def scheduled_sample_prob(ground_truth_x,
                          generated_x,
                          batch_size,
                          scheduled_sample_var):
  """Probability based scheduled sampling.

  One uniform draw in [0, 1) is made per example. Draws above
  `scheduled_sample_var` pick the generated example, all others the
  ground-truth one.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size.
    scheduled_sample_var: probability of choosing from ground_truth.

  Returns:
    New batch with randomly selected data points.
  """
  draws = tf.random_uniform([batch_size])
  use_generated = draws > scheduled_sample_var
  return tf.where(use_generated, generated_x, ground_truth_x)
def dna_transformation(prev_image, dna_input, dna_kernel_size, relu_shift):
  """Apply dynamic neural advection to previous image.

  Args:
    prev_image: previous image to be transformed; a 4-D tensor with statically
      known height (axis 1) and width (axis 2).
    dna_input: hidden layer to be used for computing the DNA transformation;
      its last axis is expected to hold one weight per kernel tap
      (dna_kernel_size**2 entries).
    dna_kernel_size: dna kernel size.
    relu_shift: shift for ReLU function.

  Returns:
    A single tensor shaped like prev_image. Each pixel is a normalized,
    per-pixel weighted sum over its dna_kernel_size x dna_kernel_size
    neighbourhood.
  """
  # "SAME"-style padding so that every kernel offset below yields a full
  # image-sized window. A fixed 2-pixel pad only fits dna_kernel_size == 5
  # (for which this gives exactly 2 per side): larger kernels would slice out
  # of bounds and smaller ones would be off-centre.
  pad_before = (dna_kernel_size - 1) // 2
  pad_after = dna_kernel_size // 2
  spatial_pad = [pad_before, pad_after]
  prev_image_pad = tf.pad(prev_image,
                          [[0, 0], spatial_pad, spatial_pad, [0, 0]])
  image_height = int(prev_image.get_shape()[1])
  image_width = int(prev_image.get_shape()[2])

  # Stack every shifted copy of the image along a new axis 3.
  inputs = []
  for xkern in range(dna_kernel_size):
    for ykern in range(dna_kernel_size):
      inputs.append(
          tf.expand_dims(
              tf.slice(prev_image_pad, [0, xkern, ykern, 0],
                       [-1, image_height, image_width, -1]), [3]))
  inputs = tf.concat(axis=3, values=inputs)

  # Shifted ReLU keeps every weight >= relu_shift; then normalize the weights
  # of each pixel's kernel to sum to 1.
  kernel = tf.nn.relu(dna_input - relu_shift) + relu_shift
  kernel = tf.expand_dims(
      kernel / tf.reduce_sum(kernel, [3], keep_dims=True), [4])
  return tf.reduce_sum(kernel * inputs, [3], keep_dims=False)
def cdna_transformation(prev_image, cdna_input, num_masks, color_channels,
                        dna_kernel_size, relu_shift):
  """Apply convolutional dynamic neural advection to previous image.

  Args:
    prev_image: previous image to be transformed; a 4-D tensor with the batch
      on axis 0, statically known height/width on axes 1/2 and the color
      channels on axis 3.
    cdna_input: hidden layer to be used for computing CDNA kernels.
    num_masks: number of masks and hence the number of CDNA transformations.
    color_channels: the number of color channels in the images.
    dna_kernel_size: dna kernel size.
    relu_shift: shift for ReLU function.

  Returns:
    List of num_masks images transformed by the predicted CDNA kernels, each
    of shape [batch, height, width, color_channels].
  """
  batch_size = tf.shape(cdna_input)[0]
  height = int(prev_image.get_shape()[1])
  width = int(prev_image.get_shape()[2])

  # Predict kernels using linear function of last hidden layer.
  cdna_kerns = tfl.dense(
      cdna_input, dna_kernel_size * dna_kernel_size * num_masks,
      name="cdna_params",
      activation=None)

  # Reshape and normalize: -> [batch, k, k, 1, num_masks]. The shifted ReLU
  # keeps every tap >= relu_shift, and dividing by the sum over axes 1-3 makes
  # each kernel sum to 1.
  cdna_kerns = tf.reshape(
      cdna_kerns, [batch_size, dna_kernel_size, dna_kernel_size, 1, num_masks])
  cdna_kerns = (tf.nn.relu(cdna_kerns - relu_shift) + relu_shift)
  norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
  cdna_kerns /= norm_factor

  # Treat the color channel dimension as the batch dimension since the same
  # transformation is applied to each color channel.
  # Treat the batch dimension as the channel dimension so that
  # depthwise_conv2d can apply a different transformation to each sample.
  # Kernels: -> [k, k, batch, num_masks] (in_channels=batch, multiplier=masks).
  cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
  cdna_kerns = tf.reshape(
      cdna_kerns, [dna_kernel_size, dna_kernel_size, batch_size, num_masks])
  # Swap the batch and channel dimensions: -> [channels, height, width, batch].
  prev_image = tf.transpose(prev_image, [3, 1, 2, 0])

  # Transform image. Output: [channels, height, width, batch * num_masks].
  transformed = tf.nn.depthwise_conv2d(
      prev_image, cdna_kerns, [1, 1, 1, 1], "SAME")

  # Transpose the dimensions to where they belong:
  # -> [batch, height, width, channels, num_masks], then split per mask.
  transformed = tf.reshape(
      transformed, [color_channels, height, width, batch_size, num_masks])
  transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
  transformed = tf.unstack(transformed, axis=-1)
  return transformed
def vgg_layer(inputs,
              nout,
              kernel_size=3,
              activation=tf.nn.leaky_relu,
              padding="SAME",
              is_training=True,
              has_batchnorm=False,
              scope=None):
  """One VGG-style block: conv2d, optional batch norm, then the activation.

  Args:
    inputs: image tensor.
    nout: number of output channels.
    kernel_size: size of the kernel.
    activation: activation function applied last.
    padding: padding of the convolution.
    is_training: whether batch norm (if any) runs in training mode.
    has_batchnorm: whether batch norm is applied between conv and activation.
    scope: variable scope of the op.

  Returns:
    net: output of layer.
  """
  with tf.variable_scope(scope):
    conv_out = tfl.conv2d(inputs, nout, kernel_size=kernel_size,
                          padding=padding, activation=None, name="conv")
    pre_activation = (
        tfl.batch_normalization(conv_out, training=is_training, name="bn")
        if has_batchnorm else conv_out)
    net = activation(pre_activation)
  return net
def tile_and_concat(image, latent, concat_latent=True):
  """Tile latent and concatenate to image across depth.

  The latent vector is laid out along the height axis, repeated as many whole
  times as fit, broadcast across the width, and zero-padded to the image
  height.

  Args:
    image: 4-D Tensor, (batch_size X height X width X channels)
    latent: 2-D Tensor, (batch_size X latent_dims)
    concat_latent: If set to False, the image is returned as is.

  Returns:
    concat_latent: 4-D Tensor, (batch_size X height X width X channels+1)
      latent tiled and concatenated to the image across the channels.
  """
  if not concat_latent:
    return image
  image_shape = common_layers.shape_list(image)
  latent_shape = common_layers.shape_list(latent)
  height, width = image_shape[1], image_shape[2]
  latent_dims = latent_shape[1]
  height_multiples = height // latent_dims
  pad = height - (height_multiples * latent_dims)
  latent = tf.reshape(latent, (-1, latent_dims, 1, 1))
  latent = tf.tile(latent, (1, height_multiples, width, 1))
  # Split the padding so the two parts always add up to `pad`. With an odd pad,
  # [pad // 2, pad // 2] left the latent one row short and the concat failed;
  # the extra row now goes at the bottom.
  pad_top = pad // 2
  pad_bottom = pad - pad_top
  latent = tf.pad(latent, [[0, 0], [pad_top, pad_bottom], [0, 0], [0, 0]])
  return tf.concat([image, latent], axis=-1)
def _encode_gif(images, fps):
  """Encodes numpy frames into a gif string using WholeVideoWriter.

  Args:
    images: A 4-D `uint8` `np.array` (or a list of 3-D images) of shape
      `[time, height, width, channels]` where `channels` is 1 or 3.
    fps: frames per second of the animation.

  Returns:
    The encoded gif string returned by the writer's finish().

  Raises:
    IOError: If the ffmpeg command returns an error.
  """
  gif_writer = WholeVideoWriter(fps)
  gif_writer.write_multi(images)
  encoded = gif_writer.finish()
  return encoded
def ffmpeg_works():
  """Probe whether ffmpeg can encode a tiny two-frame clip.

  Returns:
    True if encoding succeeded, False if it raised IOError/OSError.
  """
  probe_clip = np.zeros((2, 32, 32, 3), dtype=np.uint8)
  try:
    _encode_gif(probe_clip, 2)
  except (IOError, OSError):
    return False
  return True
def py_gif_summary(tag, images, max_outputs, fps, return_summary_value=False):
  """Outputs a `Summary` protocol buffer with gif animations.

  If gif encoding fails with IOError/OSError, the first frame of the sequence
  is encoded as a PNG instead (requires PIL). If PIL cannot be imported
  either, the image payload is left empty.

  Args:
    tag: Name of the summary. May be `str` or `bytes`; `tf.py_func` (as used
      by `gif_summary`) passes string tensors in as bytes.
    images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height, width,
      channels]` where `channels` is 1 or 3.
    max_outputs: Max number of batch elements to generate gifs for.
    fps: frames per second of the animation.
    return_summary_value: If set to True, return a list of tf.Summary.Value
      objects in addition to the protocol buffer.

  Returns:
    The serialized `Summary` protocol buffer, or the tuple
    (list of `tf.Summary.Value`, serialized buffer) if return_summary_value.

  Raises:
    ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels.
  """
  images = np.asarray(images)
  if images.dtype != np.uint8:
    raise ValueError("Tensor must have dtype uint8 for gif summary.")
  if images.ndim != 5:
    raise ValueError("Tensor must be 5-D for gif summary.")
  batch_size, _, height, width, channels = images.shape
  if channels not in (1, 3):
    raise ValueError("Tensors must have 1 or 3 channels for gif summary.")
  if isinstance(tag, bytes):
    # Without decoding, "{}".format(b"x") would produce the tag "b'x'".
    tag = tag.decode("utf-8")

  summ = tf.Summary()
  all_summ_values = []
  num_outputs = min(batch_size, max_outputs)
  for i in range(num_outputs):
    image_summ = tf.Summary.Image()
    image_summ.height = height
    image_summ.width = width
    image_summ.colorspace = channels  # 1: grayscale, 3: RGB
    try:
      image_summ.encoded_image_string = _encode_gif(images[i], fps)
    except (IOError, OSError) as e:
      tf.logging.warning(
          "Unable to encode images to a gif string because either ffmpeg is "
          "not installed or ffmpeg returned an error: %s. Falling back to an "
          "image summary of the first frame in the sequence.", e)
      try:
        from PIL import Image  # pylint: disable=g-import-not-at-top
        import io  # pylint: disable=g-import-not-at-top
        first_frame = images[i][0]
        if channels == 1:
          # PIL cannot build an image from an (H, W, 1) array; drop the
          # singleton channel axis to get a 2-D grayscale frame.
          first_frame = first_frame[:, :, 0]
        with io.BytesIO() as output:
          Image.fromarray(first_frame).save(output, "PNG")
          image_summ.encoded_image_string = output.getvalue()
      except ImportError as import_error:
        tf.logging.warning(
            "Gif summaries requires ffmpeg or PIL to be installed: %s",
            import_error)
        # encoded_image_string is a bytes field, so assign bytes, not str.
        image_summ.encoded_image_string = b""
    if num_outputs == 1:
      summ_tag = "{}/gif".format(tag)
    else:
      summ_tag = "{}/gif/{}".format(tag, i)
    curr_summ_value = tf.Summary.Value(tag=summ_tag, image=image_summ)
    all_summ_values.append(curr_summ_value)
    summ.value.add(tag=summ_tag, image=image_summ)
  summ_str = summ.SerializeToString()
  if return_summary_value:
    return all_summ_values, summ_str
  return summ_str
def gif_summary(name, tensor, max_outputs=3, fps=10, collections=None,
                family=None):
  """Builds a graph op that emits animated-gif summaries of a video batch.

  The actual encoding is done in Python by `py_gif_summary`, wrapped in a
  `tf.py_func`.

  Args:
    name: Name of the summary.
    tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width,
      channels]` where `channels` is 1 or 3.
    max_outputs: Max number of batch elements to generate gifs for.
    fps: frames per second of the animation
    collections: Optional list of tf.GraphKeys. The collections to add the
      summary to. Defaults to [tf.GraphKeys.SUMMARIES]
    family: Optional; if provided, used as the prefix of the summary tag name,
      which controls the tab name used for display on Tensorboard.

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer (an empty string constant when summaries are skipped).

  Raises:
    ValueError: if the given tensor does not have rank 5.
  """
  tensor = tf.convert_to_tensor(tensor)
  static_shape = tensor.get_shape()
  if len(static_shape) != 5:
    raise ValueError("Assuming videos given as tensors in the format "
                     "[batch, time, height, width, channels] but got one "
                     "of shape: %s" % str(static_shape))
  tensor = tf.cast(tensor, tf.uint8)
  if distribute_summary_op_util.skip_summary():
    return tf.constant("")
  with summary_op_util.summary_scope(
      name, family, values=[tensor]) as (tag, scope):
    summary_op = tf.py_func(
        py_gif_summary,
        [tag, tensor, max_outputs, fps],
        tf.string,
        stateful=False,
        name=scope)
    summary_op_util.collect(
        summary_op, collections, [tf.GraphKeys.SUMMARIES])
  return summary_op
def tinyify(array, tiny_mode, small_mode):
  """Shrinks a list of layer widths for tiny / small model configurations.

  Args:
    array: iterable of integer widths.
    tiny_mode: if true, every width becomes 1 (takes precedence).
    small_mode: if true, every width is quartered, but never below 1.

  Returns:
    A new list of widths, or `array` itself when neither mode is set.
  """
  if tiny_mode:
    shrink = lambda _: 1
  elif small_mode:
    shrink = lambda width: max(width // 4, 1)
  else:
    return array
  return [shrink(width) for width in array]
def get_gaussian_tensor(mean, log_var):
  """Samples from N(mean, exp(log_var)) via the reparameterization trick.

  Args:
    mean: Tensor of means; the standard-normal noise takes its shape.
    log_var: Tensor of log variances, broadcast against `mean`.

  Returns:
    mean + exp(log_var / 2) * epsilon, with epsilon ~ N(0, 1) in float32.
  """
  epsilon = tf.random_normal(tf.shape(mean), 0, 1, dtype=tf.float32)
  return mean + tf.exp(log_var / 2.0) * epsilon
def conv_latent_tower(images, time_axis, latent_channels=1, min_logvar=-5,
                      is_training=False, random_latent=False,
                      tiny_mode=False, small_mode=False):
  """Builds convolutional latent tower for stochastic model.

  At training time this tower predicts a latent distribution (mean and log
  variance) conditioned on the entire video: all frames are concatenated
  along the channel axis and passed through a small conv stack. This latent
  variable is fed to the main tower as an extra input for future frames
  prediction. At inference time the tower's output is replaced by zero mean
  and zero log variance, i.e. the parameters of N(0, 1).

  Args:
    images: tensor of ground truth image sequences
    time_axis: the time axis in images tensor
    latent_channels: number of latent channels
    min_logvar: minimum value for log_var (the predicted log variance goes
      through a relu, then this offset is added)
    is_training: whether or not it is training mode
    random_latent: whether or not to return standard-normal parameters
      (zeros) instead of the predicted ones during training. Either a Python
      bool or a scalar boolean Tensor.
    tiny_mode: whether or not it is tiny_mode. tiny_mode sets the number
      of conv channels to 1 at each layer. useful for testing the
      integration tests.
    small_mode: whether or not it is small_mode. small mode is the same model
      with less conv and lstm layers and also lower number of channels.
      suitable for videos with less complexity and testing.

  Returns:
    latent_mean: predicted latent mean
    latent_logvar: predicted latent log variance
  """
  conv_size = tinyify([32, 64, 64], tiny_mode, small_mode)
  with tf.variable_scope("latent", reuse=tf.AUTO_REUSE):
    images = tf.to_float(images)
    images = tf.unstack(images, axis=time_axis)
    images = tf.concat(images, axis=3)

    x = images
    x = common_layers.make_even_size(x)
    x = tfl.conv2d(x, conv_size[0], [3, 3], strides=(2, 2),
                   padding="SAME", activation=tf.nn.relu, name="latent_conv1")
    x = contrib.layers().layer_norm(x)
    if not small_mode:
      x = tfl.conv2d(x, conv_size[1], [3, 3], strides=(2, 2),
                     padding="SAME", activation=tf.nn.relu, name="latent_conv2")
      x = contrib.layers().layer_norm(x)
    x = tfl.conv2d(x, conv_size[2], [3, 3], strides=(1, 1),
                   padding="SAME", activation=tf.nn.relu, name="latent_conv3")
    x = contrib.layers().layer_norm(x)

    nc = latent_channels
    mean = tfl.conv2d(x, nc, [3, 3], strides=(2, 2),
                      padding="SAME", activation=None, name="latent_mean")
    logv = tfl.conv2d(x, nc, [3, 3], strides=(2, 2),
                      padding="SAME", activation=tf.nn.relu, name="latent_std")
    logvar = logv + min_logvar

    # No latent tower at inference time, just standard gaussian.
    if not is_training:
      return tf.zeros_like(mean), tf.zeros_like(logvar)

    # No latent in the first phase.
    if isinstance(random_latent, bool):
      # tf.cond rejects Python bools in graph mode (including the default
      # `random_latent=False`), so resolve a static flag here instead.
      if random_latent:
        return tf.zeros_like(mean), tf.zeros_like(logvar)
      return mean, logvar
    ret_mean, ret_logvar = tf.cond(
        random_latent,
        lambda: (tf.zeros_like(mean), tf.zeros_like(logvar)),
        lambda: (mean, logvar))
    return ret_mean, ret_logvar
def beta_schedule(schedule, global_step, final_beta, decay_start, decay_end):
  """Get KL multiplier (beta) based on the schedule.

  beta is 0 before `decay_start` and `final_beta` after `decay_end`. In
  between it is `max(0, final_beta - decayed)`, where `decayed` is a value
  that `schedule` anneals from `final_beta` down towards 0. Some TF schedules
  only support decaying a value, hence this "decay then reverse" form:
    * "constant": decayed is 0, so beta is `final_beta` from `decay_start` on.
    * "linear": tf.train.polynomial_decay to 0 over the window.
    * "noisy_linear_cosine_decay": tf.train.noisy_linear_cosine_decay.

  Args:
    schedule: One of "constant", "linear", "noisy_linear_cosine_decay".
    global_step: Current training step.
    final_beta: Value beta reaches at the end of the schedule.
    decay_start: Step at which beta stops being 0.
    decay_end: Step after which beta stays at `final_beta`.

  Returns:
    Scalar Tensor with the current beta.

  Raises:
    ValueError: If decay_start > decay_end or the schedule is unknown.
  """
  if decay_start > decay_end:
    raise ValueError("decay_end is smaller than decay_start.")

  # Since some of the TF schedules do not support incrementing a value,
  # in all of the schedules, we anneal the beta from final_beta to zero
  # and then reverse it at the bottom.
  if schedule == "constant":
    decayed_value = 0.0
  elif schedule == "linear":
    decayed_value = tf.train.polynomial_decay(
        learning_rate=final_beta,
        global_step=global_step - decay_start,
        decay_steps=decay_end - decay_start,
        end_learning_rate=0.0)
  elif schedule == "noisy_linear_cosine_decay":
    decayed_value = tf.train.noisy_linear_cosine_decay(
        learning_rate=final_beta,
        global_step=global_step - decay_start,
        decay_steps=decay_end - decay_start)
  # TODO(mechcoder): Add log_annealing schedule.
  else:
    raise ValueError("Unknown beta schedule.")

  increased_value = final_beta - decayed_value
  increased_value = tf.maximum(0.0, increased_value)

  # An ordered list (rather than a dict) keeps tf.case's predicate order
  # deterministic; the two predicates cannot both hold since
  # decay_start <= decay_end.
  beta = tf.case(
      pred_fn_pairs=[
          (tf.less(global_step, decay_start), lambda: 0.0),
          (tf.greater(global_step, decay_end), lambda: final_beta)],
      default=lambda: increased_value)
  return beta
def extract_random_video_patch(videos, num_frames=-1):
  """Crops a random window of `num_frames` consecutive frames per video.

  Every video in the batch gets its own independently sampled start frame.

  Args:
    videos: 5-D Tensor, (NTHWC)
    num_frames: Integer, if -1 then the entire video is returned.

  Returns:
    video_patch: 5-D Tensor, (NTHWC) with T = num_frames.

  Raises:
    ValueError: If num_frames is greater than the number of total frames in
      the video.
  """
  if num_frames == -1:
    return videos
  n, t, height, width, channels = common_layers.shape_list(videos)
  if t < num_frames:
    raise ValueError("Expected num_frames <= %d, got %d" %
                     (t, num_frames))

  # One start index per video; the last valid start is t - num_frames, and
  # maxval is exclusive, hence the +1.
  starts = tf.random_uniform(
      shape=(n,), minval=0, maxval=t - num_frames + 1, dtype=tf.int32)

  # Time indices, flattened: row b is starts[b] .. starts[b] + num_frames - 1.
  time_idx = (tf.expand_dims(tf.range(num_frames), axis=0) +
              tf.expand_dims(starts, axis=1))
  time_idx = tf.reshape(time_idx, [-1])

  # Batch indices, flattened: each video index repeated num_frames times.
  batch_idx = tf.reshape(
      tf.tile(tf.expand_dims(tf.range(n), axis=1), [1, num_frames]), [-1])

  patches = tf.gather_nd(videos, tf.stack((batch_idx, time_idx), axis=1))
  return tf.reshape(patches, (n, num_frames, height, width, channels))
class VideoWriter(object):
  """Abstract base for objects that consume video frames.

  Subclasses implement `write` and, if they persist results, `save_to_disk`.
  `finish` may be overridden to flush work and return the encoded output.
  """

  def write(self, frame, encoded_frame=None):
    """Consumes one frame; must be implemented by subclasses."""
    raise NotImplementedError

  def write_multi(self, frames, encoded_frames=None):
    """Feeds a sequence of frames (and optional encodings) to `write`.

    Without `encoded_frames`, every frame is written with encoding None.
    Otherwise frames and encodings are paired up, and writing stops at the
    end of the shorter sequence.
    """
    if encoded_frames is None:
      for frame in frames:
        self.write(frame, None)
      return
    for frame, encoded_frame in zip(frames, encoded_frames):
      self.write(frame, encoded_frame)

  def finish(self):
    """Finalizes writing, releases resources and returns any output.

    The base implementation does nothing and returns None.
    """
    return None

  def save_to_disk(self, output):
    """Persists the value returned by `finish()`; not implemented here.

    Args:
      output: result of finish().
    """
    raise NotImplementedError

  def finish_to_disk(self):
    """Calls `finish()` and saves its output when there is any."""
    result = self.finish()  # pylint: disable=assignment-from-no-return
    if result is None:
      return
    self.save_to_disk(result)

  def __del__(self):
    """Runs `finish()` when the writer is garbage-collected."""
    self.finish()
class WholeVideoWriter(VideoWriter):
  """Helper class for writing whole videos.

  Raw frames are streamed into an `ffmpeg` subprocess that is started on the
  first `write`. `finish` closes ffmpeg's input and returns everything ffmpeg
  wrote to stdout: the encoded video, in `file_format` (gif by default).
  Background threads drain ffmpeg's stdout and stderr while frames are
  written, so a full pipe cannot stall the writer.
  """

  def __init__(self, fps, output_path=None, file_format="gif"):
    self.fps = fps
    self.output_path = output_path
    self.file_format = file_format
    self.proc = None
    self._out_chunks = []
    self._err_chunks = []
    self._out_thread = None
    self._err_thread = None

  def __init_ffmpeg(self, image_shape):
    """Initializes ffmpeg to write frames.

    Args:
      image_shape: (height, width, channels) of the frames. channels must be
        1 (gray) or 3 (rgb24); any other value raises KeyError.
    """
    import itertools  # pylint: disable=g-import-not-at-top
    from subprocess import Popen, PIPE  # pylint: disable=g-import-not-at-top,g-multiple-import,g-importing-member
    ffmpeg = "ffmpeg"
    height, width, channels = image_shape
    self.cmd = [
        ffmpeg, "-y",
        "-f", "rawvideo",
        "-vcodec", "rawvideo",
        "-r", "%.02f" % self.fps,
        "-s", "%dx%d" % (width, height),
        "-pix_fmt", {1: "gray", 3: "rgb24"}[channels],
        "-i", "-",
        "-filter_complex", "[0:v]split[x][z];[x]fifo[w];[z]palettegen,fifo[y];"
        "[w][y]paletteuse,fifo",
        "-r", "%.02f" % self.fps,
        "-f", self.file_format,
        "-qscale", "0",
        "-"
    ]
    # Fresh buffers per encoding run, so output from a previous run is never
    # prepended to this one.
    self._out_chunks = []
    self._err_chunks = []
    self.proc = Popen(
        self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=-1
    )
    (self._out_thread, self._err_thread) = itertools.starmap(
        self._start_reader_thread, [
            (self.proc.stdout, self._out_chunks),
            (self.proc.stderr, self._err_chunks)
        ]
    )

  def _start_reader_thread(self, stream, chunks):
    """Starts a thread for reading output from FFMPEG.

    The thread reads consecutive chunks from the stream and saves them in
    the given list until the stream reaches EOF.

    Args:
      stream: output stream of the FFMPEG process.
      chunks: list to save output chunks to.

    Returns:
      Thread
    """
    import io  # pylint: disable=g-import-not-at-top
    import threading  # pylint: disable=g-import-not-at-top
    def target():
      while True:
        chunk = stream.read(io.DEFAULT_BUFFER_SIZE)
        if not chunk:
          break
        chunks.append(chunk)
    thread = threading.Thread(target=target)
    thread.start()
    return thread

  def write(self, frame, encoded_frame=None):
    """Streams one raw frame to ffmpeg, starting ffmpeg on the first call.

    Args:
      frame: `np.array` of shape (height, width, channels); presumably uint8,
        since the gray / rgb24 input pixel formats are 8-bit.
      encoded_frame: unused.
    """
    if self.proc is None:
      self.__init_ffmpeg(frame.shape)
    # tobytes() yields the same bytes as the deprecated tostring().
    self.proc.stdin.write(frame.tobytes())

  def finish(self):
    """Finishes transcoding and returns the video.

    Returns:
      bytes written by ffmpeg, or None if no frame was ever written.

    Raises:
      IOError: in case of transcoding error (non-zero ffmpeg exit status).
    """
    if self.proc is None:
      return None
    # Detach the process first so that neither a failure below nor a later
    # __del__ call tries to finish the same process twice.
    proc, self.proc = self.proc, None
    # Closing stdin signals end of input to ffmpeg; the reader threads return
    # once ffmpeg closes its output pipes.
    proc.stdin.close()
    for thread in (self._out_thread, self._err_thread):
      thread.join()
    out = b"".join(self._out_chunks)
    err = b"".join(self._err_chunks)
    proc.stdout.close()
    proc.stderr.close()
    # `returncode` is only set once the process is reaped, so wait for it;
    # checking it without wait() misses errors and leaves a zombie process.
    if proc.wait():
      raise IOError("\n".join(
          [" ".join(self.cmd), err.decode("utf8", errors="replace")]))
    return out

  def save_to_disk(self, output):
    """Writes the bytes returned by `finish()` to `output_path`."""
    if self.output_path is None:
      raise ValueError(
          "This writer doesn't support saving to disk (output_path not "
          "specified)."
      )
    with tf.gfile.Open(self.output_path, "wb") as f:
      f.write(output)
class BatchWholeVideoWriter(VideoWriter):
  """Helper class for writing videos in batch.

  Keeps one `WholeVideoWriter` per batch element; the writer for element `i`
  saves to `path_template.format(i)`.
  """

  def __init__(self, fps, path_template, file_format="gif"):
    self.fps = fps
    self.path_template = path_template
    self.file_format = file_format
    self.writers = None

  def write(self, batch_frame, batch_encoded_frame=None):
    """Writes frame `i` of the batch to the i-th per-video writer.

    The per-video writers are created on the first call, one per element of
    that first batch.
    """
    del batch_encoded_frame
    if self.writers is None:
      self.writers = [
          WholeVideoWriter(  # pylint: disable=g-complex-comprehension
              self.fps, self.path_template.format(i), self.file_format
          )
          for i in range(len(batch_frame))
      ]
    for i, frame in enumerate(batch_frame):
      self.writers[i].write(frame)

  def finish(self):
    """Finishes every per-video writer.

    Returns:
      A list with one `finish()` result per batch element, or None when no
      frame has been written yet (so `finish_to_disk` and `__del__` become
      no-ops instead of failing on the missing writer list).
    """
    if self.writers is None:
      return None
    return [w.finish() for w in self.writers]

  def save_to_disk(self, outputs):
    """Saves each output with the per-video writer that produced it."""
    for (writer, output) in zip(self.writers or [], outputs):
      writer.save_to_disk(output)
class IndividualFrameWriter(VideoWriter):
  """Writes every already-encoded frame to its own numbered file."""

  def __init__(self, output_dir):
    self.output_dir = output_dir
    # Index used to name the next file: frame_00000.png, frame_00001.png, ...
    self._counter = 0

  def write(self, frame=None, encoded_frame=None):
    """Saves the bytes of `encoded_frame` to the next numbered file.

    The raw `frame` argument is ignored.

    Raises:
      ValueError: if `encoded_frame` is None.
    """
    import os  # pylint: disable=g-import-not-at-top
    if encoded_frame is None:
      raise ValueError("This writer only supports encoded frames.")
    file_name = "frame_%05d.png" % self._counter
    frame_path = os.path.join(self.output_dir, file_name)
    with tf.gfile.Open(frame_path, "wb") as frame_file:
      frame_file.write(encoded_frame)
    self._counter += 1
================================================
FILE: tensor2tensor/layers/common_video_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for video utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensor2tensor.layers import common_video
from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf
# Turn on eager execution for this whole test module at import time. Tests
# that need graph mode either use test_utils graph-mode decorators or build an
# explicit tf.Graph().
tf.enable_eager_execution()
class CommonVideoTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the video helpers in `common_video`."""

  def _run_scheduled_sample_func(self, func, var, batch_size):
    """Evaluates a scheduled-sampling function on a toy batch.

    The ground truth is 1..batch_size and the "generated" batch is its
    negation, so the sign of each sampled element shows which batch it came
    from.

    Returns:
      A list [ground_truth, generated, sampled] of evaluated values.
    """
    truth_list = list(range(1, batch_size + 1))
    generated_list = [-value for value in truth_list]
    truth = tf.convert_to_tensor(truth_list)
    generated = tf.convert_to_tensor(generated_list)
    sampled = func(truth, generated, batch_size, var)
    return self.evaluate([truth, generated, sampled])

  @test_utils.run_in_graph_and_eager_modes()
  def testScheduledSampleProbStart(self):
    truth, _, sampled = self._run_scheduled_sample_func(
        common_video.scheduled_sample_prob, 1.0, 10)
    self.assertAllEqual(truth, sampled)

  @test_utils.run_in_graph_and_eager_modes()
  def testScheduledSampleProbMid(self):
    _, _, sampled = self._run_scheduled_sample_func(
        common_video.scheduled_sample_prob, 0.5, 1000)
    # Positive entries come from the ground truth; expect roughly half.
    from_truth = np.sum(sampled > 0)
    self.assertAlmostEqual(from_truth / 1000.0, 0.5, places=1)

  @test_utils.run_in_graph_and_eager_modes()
  def testScheduledSampleProbEnd(self):
    _, generated, sampled = self._run_scheduled_sample_func(
        common_video.scheduled_sample_prob, 0.0, 10)
    self.assertAllEqual(generated, sampled)

  @test_utils.run_in_graph_and_eager_modes()
  def testScheduledSampleCountStart(self):
    truth, _, sampled = self._run_scheduled_sample_func(
        common_video.scheduled_sample_count, 10, 10)
    self.assertAllEqual(truth, sampled)

  @test_utils.run_in_graph_and_eager_modes()
  def testScheduledSampleCountMid(self):
    _, _, sampled = self._run_scheduled_sample_func(
        common_video.scheduled_sample_count, 5, 10)
    from_truth = np.sum(sampled > 0)
    self.assertEqual(from_truth, 5)

  @test_utils.run_in_graph_and_eager_modes()
  def testScheduledSampleCountEnd(self):
    _, generated, sampled = self._run_scheduled_sample_func(
        common_video.scheduled_sample_count, 0, 10)
    self.assertAllEqual(generated, sampled)

  @test_utils.run_in_graph_and_eager_modes()
  def testDynamicTileAndConcat(self):
    # One 4x4 single-channel image: shape (1, 4, 4, 1).
    pixels = [[1, 2, 3, 4],
              [2, 4, 5, 6],
              [7, 8, 9, 10],
              [7, 9, 10, 1]]
    image = tf.cast(
        tf.expand_dims(tf.expand_dims(pixels, axis=0), axis=-1),
        dtype=tf.float32)
    # One 2-dimensional latent: shape (1, 2).
    latent = tf.cast(
        tf.convert_to_tensor(np.array([[90, 100]])), dtype=tf.float32)
    tiled = common_video.tile_and_concat(image, latent)
    tiled_np, image_np = self.evaluate([tiled, image])
    latent_channel = tiled_np[0, :, :, -1]
    self.assertAllEqual(tiled_np.shape, (1, 4, 4, 2))
    self.assertAllEqual(tiled_np[:, :, :, :1], image_np)
    # Latent entries alternate down the height and repeat across the width.
    self.assertAllEqual(
        latent_channel,
        [[90, 90, 90, 90],
         [100, 100, 100, 100],
         [90, 90, 90, 90],
         [100, 100, 100, 100]])

  @test_utils.run_in_graph_mode_only()
  def testGifSummary(self):
    for channels in (1, 3):
      # batch, time, height, width, channels
      shape = (1, 12, 48, 64, channels)
      frames = np.random.randint(256, size=shape).astype(np.uint8)
      with self.test_session():
        summary_op = common_video.gif_summary(
            "gif", tf.convert_to_tensor(frames), fps=10)
        serialized = summary_op.eval()
      parsed = tf.Summary()
      parsed.ParseFromString(serialized)
      self.assertEqual(1, len(parsed.value))
      self.assertTrue(parsed.value[0].HasField("image"))
      encoded = parsed.value[0].image.encoded_image_string
      self.assertEqual(encoded, common_video._encode_gif(frames[0], fps=10))  # pylint: disable=protected-access

  def check_if_patch_exists(self, videos, video_patches, num_frames):
    """Check that given patch is present in video."""
    for video, patch in zip(videos, video_patches):
      starts = range(len(video) - num_frames + 1)
      found = any(
          np.allclose(video[start:start + num_frames], patch)
          for start in starts)
      self.assertTrue(found)

  def testBasicLstm(self):
    """Tests that the parameters of the LSTM are shared across time."""
    with tf.Graph().as_default():
      lstm_state = None
      for _ in range(10):
        step_input = tf.random_uniform(shape=(32, 16))
        _, lstm_state = common_video.basic_lstm(
            step_input, lstm_state, num_units=100, name="basic")
      total_params = np.sum(
          [np.prod(var.shape) for var in tf.trainable_variables()])
      # 4 * ((100 + 16)*100 + 100) => 4 * (W_{xh} + W_{hh} + b)
      self.assertEqual(total_params, 46800)

  @parameterized.named_parameters(
      ("two_frames", 2), ("ten_frames", 10), ("default", -1))
  def testExtractRandomVideoPatch(self, num_frames=2):
    with tf.Graph().as_default():
      rng = np.random.RandomState(0)
      videos_np = rng.randint(0, 255, size=(12, 20, 256, 256, 3))
      patch_op = common_video.extract_random_video_patch(
          tf.convert_to_tensor(videos_np), num_frames=num_frames)
      with tf.Session() as sess:
        patches_np = sess.run(patch_op)
      if num_frames == -1:
        self.assertTrue(np.allclose(videos_np, patches_np))
      else:
        self.assertEqual(patches_np.shape, (12, num_frames, 256, 256, 3))
        self.check_if_patch_exists(videos_np, patches_np, num_frames)
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/discretization.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discretization bottlenecks used to train discrete latent variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial # pylint: disable=g-importing-member
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import common_layers
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow_probability as tfp
from tensorflow.python.training import moving_averages # pylint: disable=g-direct-tensorflow-import
def project_hidden(x, projection_tensors, hidden_size, num_blocks):
  """Project encoder hidden state under num_blocks using projection tensors.

  Every flattened hidden vector is multiplied by each block's projection
  matrix, producing one projected vector per block.

  Args:
    x: Encoder hidden state of shape [batch_size, latent_dim, hidden_size].
    projection_tensors: Projection tensors used to project the hidden state;
      presumably [num_blocks, hidden_size, hidden_size / num_blocks] (TODO
      confirm against callers).
    hidden_size: Dimension of the latent space.
    num_blocks: Number of blocks in DVQ.

  Returns:
    x_projected: Projected states of shape [batch_size, latent_dim, num_blocks,
      hidden_size / num_blocks].
  """
  batch_size, latent_dim, _ = common_layers.shape_list(x)
  flat = tf.reshape(x, shape=[1, -1, hidden_size])
  per_block = tf.reshape(
      tf.tile(flat, multiples=[num_blocks, 1, 1]),
      shape=[num_blocks, -1, hidden_size])
  # [num_blocks, N, hidden] @ projections -> transpose to [N, num_blocks, .].
  projected = tf.transpose(
      tf.matmul(per_block, projection_tensors), perm=[1, 0, 2])
  return tf.reshape(projected, [batch_size, latent_dim, num_blocks, -1])
def slice_hidden(x, hidden_size, num_blocks):
  """Splits the hidden vector of each latent into `num_blocks` equal slices.

  Args:
    x: Encoder hidden state of shape [batch_size, latent_dim, hidden_size].
    hidden_size: Dimension of the latent space.
    num_blocks: Number of blocks in DVQ.

  Returns:
    Sliced states of shape [batch_size, latent_dim, num_blocks, block_dim],
    where block_dim = hidden_size // num_blocks.
  """
  batch_size, latent_dim, _ = common_layers.shape_list(x)
  return tf.reshape(
      x,
      shape=[batch_size, latent_dim, num_blocks, hidden_size // num_blocks])
def nearest_neighbor(x,
                     means,
                     block_v_size,
                     random_top_k=1,
                     soft_em=False,
                     num_samples=1,
                     sum_over_latents=False,
                     summary=True):
  """Find the nearest element in means to elements in x.

  Squared Euclidean distances are computed per block via the expansion
  ||x||^2 + ||m||^2 - 2 <x, m>.

  Args:
    x: Continuous encodings of shape [batch_size, latent_dim, num_blocks,
      block_dim].
    means: Embedding table of shape [num_blocks, block_v_size, block_dim].
    block_v_size: Number of table entries per block.
    random_top_k: Noisy top-k if this is bigger than 1: instead of the nearest
      code, the one at a uniformly sampled rank in [0, random_top_k) is used
      (a single rank shared by the whole batch).
    soft_em: If True then use soft EM rather than hard EM.
    num_samples: Number of samples to take in soft EM.
    sum_over_latents: Whether to sum over non-batch dimensions when calculating
      negative entropy loss. Used only when doing soft EM.
    summary: If True then record summary histogram of entropies.

  Returns:
    Tensor with nearest element in mean encoded in one-hot notation
    and distances.
  """
  batch_size, latent_dim, num_blocks, block_dim = common_layers.shape_list(x)
  x = tf.reshape(x, [batch_size * latent_dim, num_blocks, block_dim])
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  scalar_prod = tf.matmul(
      tf.transpose(x, perm=[1, 0, 2]), tf.transpose(means, perm=[0, 2, 1]))
  scalar_prod = tf.transpose(scalar_prod, perm=[1, 0, 2])
  # dist: [batch_size * latent_dim, num_blocks, block_v_size].
  dist = x_norm_sq + tf.transpose(
      means_norm_sq, perm=[2, 0, 1]) - 2 * scalar_prod

  # computing cluster probabilities
  if soft_em:
    num_blocks = common_layers.shape_list(dist)[1]
    nearest_idx = tf.stack(
        [
            tf.multinomial(-dist[:, i, :], num_samples=num_samples)
            for i in range(num_blocks)
        ],
        axis=1)
    nearest_hot = tf.one_hot(nearest_idx, depth=block_v_size)
    neg_q_entropy = tf.reduce_sum(
        nearest_hot * tf.expand_dims(tf.nn.log_softmax(-dist), 2), axis=2)
    if sum_over_latents:
      neg_q_entropy = tf.reduce_sum(neg_q_entropy, [1, 2])
    neg_q_entropy = tf.reduce_mean(neg_q_entropy, axis=0)
    nearest_hot = tf.reduce_mean(nearest_hot, axis=-2)
    if summary:
      tf.summary.histogram("neg_q_entropy", tf.reshape(neg_q_entropy, [-1]))
  else:
    neg_q_entropy = 0.
    if random_top_k > 1:
      _, top_k_idx = tf.nn.top_k(-dist, k=random_top_k)
      # maxval is exclusive, so use random_top_k (not random_top_k - 1) to
      # make every one of the top-k candidates selectable.
      nearest_idx = tf.gather(
          top_k_idx,
          tf.random_uniform(
              [1], minval=0, maxval=random_top_k, dtype=tf.int32),
          axis=-1)
    else:
      nearest_idx = tf.argmax(-dist, axis=-1)
    nearest_hot = tf.one_hot(nearest_idx, block_v_size)
  return nearest_hot, neg_q_entropy
def embedding_lookup(x,
                     means,
                     num_blocks,
                     block_v_size,
                     bottleneck_kind="dvq",
                     random_top_k=1,
                     soft_em=False,
                     num_samples=1,
                     do_hard_gumbel_softmax=False,
                     temperature_warmup_steps=150000,
                     num_flows=0,
                     approximate_gs_entropy=False,
                     sum_over_latents=False):
  """Compute nearest neighbors and loss for training the embeddings via DVQ.

  Args:
    x: Continuous encodings of shape [batch_size, latent_dim, num_blocks,
      block_dim].
    means: Embedding table of shape [num_blocks, block_v_size, block_dim].
    num_blocks: Number of blocks in DVQ.
    block_v_size: Number of table entries per block.
    bottleneck_kind: Discrete bottleneck type.
    random_top_k: Noisy top-k if this is bigger than 1.
    soft_em: If True then use soft EM rather than hard EM.
    num_samples: Number of samples to use for soft EM.
    do_hard_gumbel_softmax: Whether to use hard or soft Gumbel-Softmax samples
      for gumbel-softmax-dvq bottleneck.
    temperature_warmup_steps: Number of steps it takes to decay temperature to
      0. Used only if bottleneck_kind is gumbel-softmax-dvq.
    num_flows: Number of inverse autoregressive flows for gumbel-softmax-dvq
      bottleneck.
    approximate_gs_entropy: Whether to approximate the Gumbel-Softmax density
      as a categorical distribution when calculating the sample entropy. Used
      only if bottleneck_kind is gumbel-softmax-dvq.
    sum_over_latents: Whether to sum over non-batch dimensions when calculating
      negative entropy loss. Used only if soft EM or when bottleneck_kind is
      gumbel-softmax-dvq.

  Returns:
    x_means_hot: The nearest neighbor in one hot form, with shape
      [batch_size * latent_dim, num_blocks, block_v_size].
    x_means: The nearest neighbor itself, with shape [batch_size * latent_dim,
      num_blocks, block_dim].
    q_loss: Scalar Tensor representing codebook loss.
    e_loss: Scalar Tensor representing commitment loss.
    neg_q_entropy: Scalar Tensor representing negative entropy of variational
      approximation (0 if it is deterministic).
  """
  if bottleneck_kind == "gumbel-softmax-dvq":
    code_hot, neg_q_entropy = gumbel_softmax_nearest_neighbor_dvq(
        x, means, block_v_size,
        hard=do_hard_gumbel_softmax,
        num_samples=num_samples,
        temperature_warmup_steps=temperature_warmup_steps,
        num_flows=num_flows,
        approximate_gs_entropy=approximate_gs_entropy,
        sum_over_latents=sum_over_latents)
  else:
    code_hot, neg_q_entropy = nearest_neighbor(
        x, means, block_v_size, random_top_k,
        soft_em=soft_em,
        num_samples=num_samples,
        sum_over_latents=sum_over_latents)

  # Select code vectors block by block:
  # [num_blocks, N, block_v_size] @ [num_blocks, block_v_size, block_dim].
  code_hot_flat = tf.reshape(code_hot, [-1, num_blocks, block_v_size])
  selected = tf.matmul(tf.transpose(code_hot_flat, perm=[1, 0, 2]), means)
  x_means = tf.transpose(selected, [1, 0, 2])

  batch_size, latent_dim, num_blocks, block_dim = common_layers.shape_list(x)
  x = tf.reshape(x, [batch_size * latent_dim, num_blocks, block_dim])

  # Both losses use mean scaling rather than summing over non-batch dims.
  # The codebook loss moves the codes; the commitment loss moves the encoder.
  q_loss = tf.reduce_mean(tf.squared_difference(tf.stop_gradient(x), x_means))
  e_loss = tf.reduce_mean(tf.squared_difference(x, tf.stop_gradient(x_means)))
  return code_hot, x_means, q_loss, e_loss, neg_q_entropy
def bit_to_int(x_bit, num_bits, base=2):
  """Collapses digit vectors (least significant digit first) into integers.

  Args:
    x_bit: Tensor whose last axis holds `num_bits` digits in `base`.
    num_bits: Number of digits in the representation.
    base: Base of the representation.

  Returns:
    int32 Tensor with the last axis of `x_bit` removed. No gradient flows
    back into `x_bit`.
  """
  digits = tf.stop_gradient(tf.to_int32(tf.reshape(x_bit, [-1, num_bits])))
  place_values = [
      digits[:, pos] * tf.to_int32(base)**tf.to_int32(pos)
      for pos in range(num_bits)
  ]
  total = sum(place_values)
  return tf.to_int32(tf.reshape(total, common_layers.shape_list(x_bit)[:-1]))
def int_to_bit(x_int, num_bits, base=2):
  """Expands integers into digit vectors, least significant digit first.

  Args:
    x_int: Tensor containing integer to be converted into base notation.
    num_bits: Number of digits in the representation.
    base: Base of the representation.

  Returns:
    float32 Tensor of shape `x_int.shape + [num_bits]` holding the digits.
  """
  expanded = tf.to_int32(tf.expand_dims(x_int, axis=-1))
  digits = []
  for pos in range(num_bits):
    # Digit at position `pos`: (x // base**pos) mod base.
    digits.append(tf.floormod(
        tf.floordiv(tf.to_int32(expanded), tf.to_int32(base)**pos),
        tf.to_int32(base)))
  return tf.to_float(tf.concat(digits, axis=-1))
def int_to_bit_embed(x_int, num_bits, embedding_size, base=2):
  """Turn x_int into a bitwise (lower-endian) tensor and embed densely.

  Args:
    x_int: Integer Tensor; its last axis is merged with the digit axis
      before the dense embedding.
    num_bits: Number of digits per integer.
    embedding_size: Output size of the dense embedding.
    base: Base of the digit representation.

  Returns:
    Tensor of shape `x_int.shape[:-1] + [embedding_size]`.
  """
  shape = common_layers.shape_list(x_int)
  inputs = int_to_bit(x_int, num_bits, base=base)
  # int_to_bit appends a digit axis of size num_bits; fold it into the last
  # axis. Using num_bits here (rather than a fixed 8) keeps the reshape valid
  # for any digit count.
  inputs = tf.reshape(inputs, shape[:-1] + [shape[-1] * num_bits])
  inputs = 2.0 * tf.to_float(inputs) - 1.0  # Move from 0/1 to -1/1.
  return tf.layers.dense(inputs, embedding_size, name="int_to_bit_embed")
def embed(x,
          hidden_size,
          z_size,
          filter_size,
          bottleneck_kind="dvq",
          soft_em=False,
          num_blocks=2,
          num_residuals=1,
          block_v_size=None,
          means=None,
          name=None):
  """Embedding function that takes discrete latent and returns embedding.

  Args:
    x: Input to the discretization bottleneck: integer codes, or (for dvq with
      soft_em) per-block code distributions.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits, where discrete codes range from 0 to
      2**z_size - 1.
    filter_size: Dimension to project embedding by. Used only if bottleneck_kind
      is semhash.
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      gumbel-softmax-dvq, semhash, gumbel-softmax, rounding (Default: dvq).
    soft_em: If True then it uses a multi-sample version of EM (Default: False).
    num_blocks: Number of blocks in DVQ (Default: 2).
    num_residuals: Number of residuals (Default: 1).
    block_v_size: Number of embedding entries per block (Default: None).
    means: The embedding table for dvq (Default: None); indexed per residual,
      so presumably a list of [num_blocks, block_v_size, block_dim] tables.
    name: Name for the bottleneck scope.

  Returns:
    Continuous embedding to be passed on to the decoder.

  Raises:
    ValueError: For unknown or missing arguments.
  """
  with tf.variable_scope(name, default_name="embed", reuse=tf.AUTO_REUSE):
    if bottleneck_kind == "semhash":
      # Embed the bit vector and its complement separately, sum them, and
      # project the result to hidden_size.
      c = int_to_bit(x, z_size)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      h1 = tf.layers.dense(h1, hidden_size, name="vch_final_linear")
    elif bottleneck_kind == "gumbel-softmax":
      hot = tf.one_hot(x, 2**z_size)
      h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
    elif bottleneck_kind in ["dvq", "gumbel-softmax-dvq"]:
      if block_v_size is None:
        raise ValueError("Bottleneck kind is dvq but block_v_size is None.")
      if soft_em:
        # x already holds per-block weights over the codebook; mix the code
        # vectors of the single residual table with those weights.
        assert num_residuals == 1
        x_hot_flat = tf.reshape(x, shape=[-1, num_blocks, block_v_size])
        h1 = tf.matmul(tf.transpose(x_hot_flat, perm=[1, 0, 2]), means[0])
        h1 = tf.transpose(h1, perm=[1, 0, 2])
        new_shape = common_layers.shape_list(x)
        new_shape[-1] = hidden_size
        h1 = tf.reshape(h1, shape=new_shape)
      else:
        # Hard codes: split each integer's z_size bits into num_residuals x
        # num_blocks groups, look each group up in its table and sum over
        # residuals.
        shape_x = common_layers.shape_list(x)
        x_flat = tf.reshape(x, [-1, 1])
        # c: [N, 1, z_size] (int_to_bit appends the bit axis).
        c = int_to_bit(x_flat, num_bits=z_size, base=2)
        shape = common_layers.shape_list(c)
        new_shape = shape
        new_shape[-1] = num_residuals
        new_shape.append(num_blocks)
        new_shape.append(int(z_size / (num_residuals * num_blocks)))
        # c: [N, 1, num_residuals, num_blocks, bits_per_block].
        c = tf.to_int32(tf.reshape(c, shape=new_shape))
        # NOTE: h1_shape aliases shape_x, which is mutated by the append.
        h1_shape = shape_x
        h1_shape.append(hidden_size)
        h1 = tf.zeros(dtype=tf.float32, shape=h1_shape)
        for i in range(num_residuals):
          # Block-level code indices for residual i: [N, 1, num_blocks].
          c_residual = bit_to_int(
              c[:, :, i, :, :],
              num_bits=int(z_size / (num_residuals * num_blocks)),
              base=2)
          c_hot = tf.one_hot(c_residual, depth=block_v_size, axis=-1)
          c_hot_flat = tf.reshape(c_hot, shape=[-1, num_blocks, block_v_size])
          h1_residual = tf.matmul(
              tf.transpose(c_hot_flat, perm=[1, 0, 2]), means[i])
          h1_residual = tf.transpose(h1_residual, perm=[1, 0, 2])
          h1_residual = tf.reshape(h1_residual, shape=h1_shape)
          h1 += h1_residual
    elif bottleneck_kind == "rounding":
      # The rounding bottleneck's latent is used directly as the embedding.
      h1 = x
    else:
      raise ValueError("Unknown bottleneck kind.")
    return h1
def vae(x, z_size, name=None):
  """Simple variational autoencoder without discretization.

  Args:
    x: Input to the bottleneck, a Tensor of shape [..., channels].
    z_size: Dimensionality of the continuous latent z. Also sets the free-bits
      threshold (z_size // 4) applied to the KL term.
    name: Name for the bottleneck scope.

  Returns:
    A tuple (z, kl_loss, mu, log_sigma):
      z: Reparameterized latent sample, shape [..., z_size].
      kl_loss: Scalar Tensor, the free-bits KL penalty against a standard
        normal prior, averaged over all positions.
      mu: Posterior mean, shape [..., z_size].
      log_sigma: Posterior log-variance (despite the name), shape
        [..., z_size].
  """
  with tf.variable_scope(name, default_name="vae"):
    mu = tf.layers.dense(x, z_size, name="mu")
    # log_sigma parameterizes the log-*variance*: the standard deviation used
    # for sampling below is exp(log_sigma / 2).
    log_sigma = tf.layers.dense(x, z_size, name="log_sigma")
    # One independent standard-normal draw per latent element. Using mu's
    # shape (instead of a hard-coded [batch, length, 1, z_size]) also supports
    # inputs of rank other than 4.
    epsilon = tf.random_normal(common_layers.shape_list(mu))
    z = mu + tf.exp(log_sigma / 2) * epsilon
    # Per-dimension KL(N(mu, var) || N(0, 1)) = 0.5 * (var + mu^2 - 1 - log var),
    # averaged over the z_size dimensions; expm1 computes var - 1 stably.
    kl = 0.5 * tf.reduce_mean(
        tf.expm1(log_sigma) + tf.square(mu) - log_sigma, axis=-1)
    # Free bits: only the part of the KL above z_size // 4 is penalized.
    free_bits = z_size // 4
    kl_loss = tf.reduce_mean(tf.maximum(kl - free_bits, 0.0))
    return z, kl_loss, mu, log_sigma
def top_k_softmax(x, k):
  """Calculate softmax(x), keep only the top-k entries and renormalize.

  Args:
    x: Logits to take the softmax over (along the last axis). The last axis
      must have at least k + 1 entries.
    k: Number of top entries to keep.

  Returns:
    A pair (probs, max_prob): the renormalized top-k distribution, with the
    same shape as x, and the largest softmax probability along the last axis.
  """
  x = tf.nn.softmax(x)
  # Take k + 1 values so the smallest of them is the threshold just below the
  # top-k: entries under it are zeroed by the relu, the threshold entry itself
  # keeps only the tiny 1e-12 offset.
  top_x, _ = tf.nn.top_k(x, k=k + 1)
  min_top = tf.reduce_min(top_x, axis=-1, keepdims=True)
  # The 1e-12 offset keeps the normalizer below strictly positive.
  x = tf.nn.relu((x - min_top) + 1e-12)
  x /= tf.reduce_sum(x, axis=-1, keepdims=True)
  return x, tf.reduce_max(top_x, axis=-1)
def gumbel_sample(shape):
  """Draw standard Gumbel noise of the given shape.

  Uses the inverse-CDF transform -log(-log(u)) with u drawn uniformly from a
  range kept away from 0 and 1, so that neither logarithm overflows.

  Args:
    shape: Shape of the noise Tensor to draw.

  Returns:
    A float Tensor of Gumbel-distributed samples with the requested shape.
  """
  u = tf.random_uniform(shape, minval=0.00001, maxval=0.99998)
  neg_log_u = -tf.log(u)
  return -tf.log(neg_log_u)
def gumbel_softmax(x,
                   z_size,
                   mode,
                   softmax_k=0,
                   temperature_warmup_steps=150000,
                   summary=True,
                   name=None):
  """Gumbel softmax discretization bottleneck.

  Args:
    x: Input to the discretization bottleneck.
    z_size: Number of bits, where discrete codes range from 0 to 2**z_size - 1.
    mode: tf.estimator.ModeKeys.
    softmax_k: If > 0 then do top-k softmax instead of Gumbel sampling.
    temperature_warmup_steps: Number of steps it takes to decay temperature to
      0.
    summary: Whether to write summaries.
    name: Name for the bottleneck scope.

  Returns:
    A tuple (m, samples, loss):
      m: Softmax over the 2**z_size codes of a dense projection of x.
      samples: Gumbel-softmax samples when mode is TRAIN; otherwise one-hot
        vectors of the argmax code, reshaped to the samples' shape.
      loss: Scalar Tensor, 5 * d_dev + 0.002 * mean(-max log-probability),
        where d_dev is the negated mean (over codes) of the batch variance of
        the selected codes' log-probabilities.
    If softmax_k > 0, returns (top-k softmax, top-k softmax,
    1 - mean of the max probability) instead, without sampling.
  """
  with tf.variable_scope(name, default_name="gumbel_softmax"):
    m = tf.layers.dense(x, 2**z_size, name="mask")
    if softmax_k > 0:
      m, kl = top_k_softmax(m, softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)

    # Gumbel-softmax sample. The noise scale and the temperature are driven by
    # common_layers' decay schedules (presumably ramping with the global step
    # over temperature_warmup_steps -- see common_layers).
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = temperature_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)

    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(
        tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
        lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    # Negative log-probability of the most likely code.
    kl = -tf.reduce_max(logsm, axis=-1)

    if summary:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))

    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, 2**z_size))

    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, 2**z_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keepdims=True)
    d_variance = tf.reduce_mean(
        tf.squared_difference(distrib, d_mean), axis=[0])
    d_dev = -tf.reduce_mean(d_variance)
    ret = s

    if mode != tf_estimator.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
def discrete_bottleneck(inputs,
                        hidden_size,
                        z_size,
                        filter_size,
                        mode=None,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        beta=0.25,
                        ema=True,
                        means=None,
                        ema_count=None,
                        ema_means=None,
                        epsilon=1e-5,
                        decay=0.999,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        softmax_k=0,
                        temperature_warmup_steps=150000,
                        do_hard_gumbel_softmax=False,
                        num_flows=0,
                        approximate_gs_entropy=False,
                        sum_over_latents=False,
                        discrete_mix=0.5,
                        noise_dev=1.,
                        startup_steps=50000,
                        summary=True,
                        name=None,
                        cond=True):
  """Discretization bottleneck.

  Passes `inputs` through the bottleneck selected by `bottleneck_kind` and
  also returns `embed` with the shared arguments bound (functools.partial), so
  callers can map discrete codes back to a dense representation.

  Args:
    inputs: Input to the bottleneck, a Tensor of shape [..., channels].
    hidden_size: Dimension of the dense output.
    z_size: Number of bits, where discrete codes range from 0 to
      2**z_size - 1.
    filter_size: Filter size in the embedding function.
    mode: tf.estimator.ModeKeys.
    bottleneck_kind: Kind of discretization bottleneck. One of dense, dvq
      (decomposed vector quantization), gumbel-softmax, gumbel-softmax-dvq,
      semhash, or vae.
    num_blocks: Number of blocks. Used only if bottleneck_kind is DVQ.
    num_residuals: Number of residual units used to compute nearest
      neighbors. Used only if bottleneck_kind is DVQ.
    reshape_method: Method to reshape. Used only if bottleneck_kind is DVQ.
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project.
    beta: Scale factor for codebook loss and EMA. Used only if bottleneck_kind
      is DVQ.
    ema: Whether to update embeddings using exponential moving averages. Used
      only if bottleneck_kind is DVQ.
    means: The embedding table. Used only if ema is True.
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to. Used only if ema is True.
    ema_means: Exponentially averaged version of the embeddings. Used only if
      ema is True.
    epsilon: Small value to avoid dividing by zero in EMA update. Used only if
      ema is True.
    decay: Decay factor for the exponential moving average. Used only if ema is
      True.
    random_top_k: Noisy top-k. Used only if bottleneck_kind is DVQ.
    soft_em: Whether to use soft EM or hard EM. Used only if bottleneck_kind is
      DVQ.
    num_samples: Number of samples for soft EM. Used only if soft_em is True.
    softmax_k: If > 0 then do top-k softmax. Used only if bottleneck_kind
      is gumbel-softmax.
    temperature_warmup_steps: Number of steps it takes to decay temperature to
      0. Used only if bottleneck_kind is gumbel-softmax or gumbel-softmax-dvq.
    do_hard_gumbel_softmax: Whether to use hard or soft Gumbel-Softmax
      samples. Used only if bottleneck_kind is gumbel-softmax-dvq.
    num_flows: Number of inverse autoregressive flows. Used only if
      bottleneck_kind is gumbel-softmax-dvq.
    approximate_gs_entropy: Whether to approximate the Gumbel-Softmax density
      as a categorical distribution when calculating the sample entropy. Used
      only if bottleneck_kind is gumbel-softmax-dvq.
    sum_over_latents: Whether to sum over all non-batch dimensions before
      taking mean of entropy loss term. Used only if bottleneck kind is DVQ
      or gumbel-softmax-dvq.
    discrete_mix: Factor for mixing discrete and non-discrete input. Used only
      if bottleneck_kind is semhash.
    noise_dev: Noise stddev. Used only if bottleneck_kind is semhash.
    startup_steps: Number of steps after which latent predictor is trained. Used
      only if bottleneck_kind is semhash.
    summary: Whether to write summaries.
    name: Name for the bottleneck scope.
    cond: A tf.bool condition on whether to update the codebook. When it is
      False the EMA statistics and means are written back unchanged.

  Returns:
    outputs_dense: Tensor of shape [..., output_dim]. The output dimension is
      hidden_size if bottleneck_kind is gumbel-softmax or semhash; filter_size
      if bottleneck_kind is dense or vae. If bottleneck_kind is DVQ,
      outputs_dense has the shape of inputs and equals the codebook (means)
      indexed by outputs_discrete on the forward pass, with a straight-through
      gradient to inputs.
    outputs_discrete: Tensor of shape [...]. Discrete codes, each an index in
      [0, 2**z_size). It uses the hot representation if soft_em is True. For
      dense and vae it is a continuous Tensor of shape [..., z_size].
    extra_loss: Scalar Tensor. Sum of codebook and commitment losses if
      bottleneck_kind is DVQ; the gumbel_softmax regularizer if
      gumbel-softmax; the free-bits KL loss if vae; zero for dense and
      semhash.
    embed_fn: Function embed with arguments partially filled in.
    neg_q_entropy: Scalar Tensor representing negative entropy of variational
      approximation (0 if it is deterministic). For DVQ kinds this is the
      value returned by embedding_lookup for the last residual only.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
      ema_count or ema_means is None if ema is True, or unknown args.
    AssertionError: If bottleneck_kind is DVQ and means is None.
  """
  if bottleneck_kind in ["dvq", "gumbel-softmax-dvq"]:
    assert means is not None
    if hidden_size % num_blocks != 0:
      raise ValueError("num_blocks does not divide hidden size")

    if z_size % num_residuals != 0:
      raise ValueError("num_residuals does not divide embedding table size")
    z_size_per_residual = int(z_size / num_residuals)

    if z_size_per_residual % num_blocks != 0:
      raise ValueError("num_blocks does not divide embedding table size")
    # Each block of each residual has its own sub-codebook of this many codes.
    block_v_size = 2**int(z_size_per_residual / num_blocks)

    if ema:
      if ema_count is None:
        raise ValueError("ema_count is None but ema is True")
      if ema_means is None:
        raise ValueError("ema_means is None but ema is True")
  else:
    block_v_size = None

  with tf.variable_scope(
      name, default_name="discrete_bottleneck", reuse=tf.AUTO_REUSE):
    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        bottleneck_kind=bottleneck_kind,
        soft_em=soft_em,
        num_blocks=num_blocks,
        num_residuals=num_residuals,
        block_v_size=block_v_size,
        means=means,
        name=name)
    if bottleneck_kind == "dense":
      # Note discrete output is continuous here.
      outputs_discrete = tf.layers.dense(inputs, z_size, name="vcc")
      outputs_dense = tf.layers.dense(
          outputs_discrete, filter_size, name="vch1")
      extra_loss = tf.constant(0.0)
      neg_q_entropy = tf.constant(0.0)
    elif bottleneck_kind in ["dvq", "gumbel-softmax-dvq"]:
      # Drop axis 2 of 4-D inputs (presumably a singleton, since tf.squeeze
      # fails otherwise) before splitting the hidden dimension into blocks.
      inputs_3d = inputs
      if len(inputs.shape) == 4:
        inputs_3d = tf.squeeze(inputs, axis=2)
      if reshape_method == "slice":
        x_reshaped = slice_hidden(
            inputs_3d, hidden_size=hidden_size, num_blocks=num_blocks)
      elif reshape_method == "project":
        if projection_tensors is None:
          raise ValueError(
              "Projection tensors is None for reshape_method project")
        x_reshaped = project_hidden(
            inputs_3d,
            projection_tensors=projection_tensors,
            hidden_size=hidden_size,
            num_blocks=num_blocks)
      else:
        raise ValueError("Unknown reshape_method")

      # Merge the two leading axes; x_res holds what is still unexplained by
      # the residual quantizers processed so far.
      x_res = tf.reshape(x_reshaped,
                         [-1] + common_layers.shape_list(x_reshaped)[2:])
      x_means_hot = []
      x_means = 0
      extra_loss = 0
      for i in range(num_residuals):
        # NOTE(review): every residual looks up x_reshaped rather than the
        # running residual x_res; only the EMA statistics below use x_res.
        x_means_hot_res, x_means_res, q_loss_res, e_loss_res, neg_q_entropy = (
            embedding_lookup(
                x_reshaped,
                means=means[i],
                num_blocks=num_blocks,
                block_v_size=block_v_size,
                bottleneck_kind=bottleneck_kind,
                random_top_k=random_top_k,
                soft_em=soft_em,
                num_samples=num_samples,
                temperature_warmup_steps=temperature_warmup_steps,
                do_hard_gumbel_softmax=do_hard_gumbel_softmax,
                num_flows=num_flows,
                approximate_gs_entropy=approximate_gs_entropy,
                sum_over_latents=sum_over_latents))
        # Update the EMA variables.
        if ema:
          tf.logging.info("Using EMA with beta = {}".format(beta))
          # tf.where(cond, new, old): when cond is False the moving average is
          # fed its own current value, so the assignment leaves it unchanged.
          updated_ema_count_res = moving_averages.assign_moving_average(
              ema_count[i],
              tf.where(cond,
                       tf.reduce_sum(
                           tf.reshape(x_means_hot_res,
                                      shape=[-1, num_blocks, block_v_size]),
                           axis=0), ema_count[i]),
              decay,
              zero_debias=False)

          # Per-block sum of the vectors assigned to each code.
          dw = tf.matmul(
              tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
              tf.transpose(x_res, perm=[1, 0, 2]))

          updated_ema_means_res = moving_averages.assign_moving_average(
              ema_means[i], tf.where(cond, dw, ema_means[i]),
              decay, zero_debias=False)
          # Laplace-smooth the counts before dividing the summed vectors.
          # NOTE(review): n sums over the block_v_size codes of each block,
          # yet the smoothing denominator uses 2**z_size * epsilon (vq_body
          # uses the codebook size); confirm this is intended.
          n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
          updated_ema_count_res = (
              (updated_ema_count_res + epsilon) / (n + 2**z_size * epsilon) * n)
          updated_ema_means_res = updated_ema_means_res / tf.expand_dims(
              updated_ema_count_res, axis=-1)

          # Evaluating extra_loss triggers the means update, which itself runs
          # after e_loss_res is computed.
          with tf.control_dependencies([e_loss_res]):
            update_means_res = tf.assign(means[i],
                                         tf.where(cond,
                                                  updated_ema_means_res,
                                                  means[i]))
            with tf.control_dependencies([update_means_res]):
              extra_loss += beta * e_loss_res
        else:
          extra_loss += q_loss_res + beta * e_loss_res

        # Update the residuals.
        x_res -= x_means_res
        x_means += x_means_res
        x_means_hot.append(x_means_hot_res)

      # Get the discrete latent representation.
      x_means_hot = tf.stack(x_means_hot, axis=1)
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation: the bits of every (residual, block)
      # index are concatenated into one z_size-bit code per position.
      x_means_bits = int_to_bit(
          x_means_idx,
          num_bits=int(z_size / (num_residuals * num_blocks)),
          base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-2]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      outputs_discrete = bit_to_int(
          tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust shape of discrete outputs.
      inputs_shape = common_layers.shape_list(inputs)
      outputs_discrete = tf.reshape(outputs_discrete, inputs_shape[:-1])

      # If we're using soft EM then set discretes to the hot representation.
      # NOTE(review): x_means_hot stacks num_residuals * num_blocks one-hot
      # vectors per position, so this reshape presumably only works when
      # num_residuals * num_blocks == 1; confirm.
      if soft_em:
        outputs_discrete = x_means_hot
        outputs_discrete = tf.reshape(outputs_discrete,
                                      inputs_shape[:-1] + [block_v_size])

      # Reshape assuming hidden_size == inputs_shape[-1].
      x_means = tf.reshape(x_means, inputs_shape)
      # Straight-through: forward value is x_means, gradient goes to inputs.
      outputs_dense = inputs + tf.stop_gradient(x_means - inputs)
    elif bottleneck_kind == "gumbel-softmax":
      _, outputs_hot, extra_loss = gumbel_softmax(
          inputs,
          z_size=z_size,
          mode=mode,
          softmax_k=softmax_k,
          temperature_warmup_steps=temperature_warmup_steps,
          summary=summary,
          name=name)
      outputs_discrete = tf.argmax(outputs_hot, axis=-1)
      outputs_dense = tf.layers.dense(
          outputs_hot, hidden_size, name="dae_dense")
      neg_q_entropy = tf.constant(0.0)
    elif bottleneck_kind == "semhash":
      outputs_discrete = tf.layers.dense(inputs, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(outputs_discrete)
      if summary:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      # Add truncated-normal noise before the sigmoid only during training.
      if noise_dev > 0 and mode == tf_estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(outputs_discrete),
            mean=0.0,
            stddev=noise_dev)
        y = common_layers.saturating_sigmoid(outputs_discrete + noise)
      else:
        y = y_clean
      # Threshold at 0.5; straight-through: forward value d, gradient of y.
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      # pd is the per-example probability of using the discrete bits. Outside
      # training it is 1.0, so the discrete bits are always used.
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf_estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      outputs_dense_a = tf.layers.dense(c, filter_size, name="vch1a")
      outputs_dense_b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      outputs_dense = outputs_dense_a + outputs_dense_b
      outputs_dense = tf.layers.dense(outputs_dense, hidden_size,
                                      name="vch_final_linear")
      dx = tf.to_int32(tf.stop_gradient(d))
      outputs_discrete = bit_to_int(dx, z_size)
      extra_loss = tf.constant(0.0)
      neg_q_entropy = tf.constant(0.0)
    elif bottleneck_kind == "vae":
      outputs_discrete, extra_loss, _, _ = vae(inputs, z_size, name="vae")
      outputs_dense = tf.layers.dense(
          outputs_discrete, filter_size, name="vch1")
      neg_q_entropy = tf.constant(0.0)
    else:
      raise ValueError("Unknown discretization method.")

  return outputs_dense, outputs_discrete, extra_loss, embed_fn, neg_q_entropy
def predict_bits_with_lstm(prediction_source, state_size, total_num_bits,
                           target_bits=None, extra_inputs=None,
                           bits_at_once=8, temperature=1.0, dropout=0.1):
  """Predict a sequence of bits (a latent) with LSTM, both training and infer.

  Given a tensor on which the predictions are based (prediction_source), we use
  a single-layer LSTM with state of size state_size to predict total_num_bits,
  which we predict in groups of size bits_at_once. During training, we use
  target_bits as input to the LSTM (teacher forcing) and return the prediction
  logits together with the prediction loss. During inference, we sample with
  the given temperature and return the predicted sequence and loss 0.

  Args:
    prediction_source: a Tensor of shape [batch_size, ...] used to create
      the initial state and the first input to the LSTM.
    state_size: python integer, the size of the LSTM state.
    total_num_bits: python integer, how many bits in total to predict; must be
      a multiple of bits_at_once.
    target_bits: a tensor of shape [batch_size, total_num_bits] used during
      training as the target to predict; each element should be -1 or 1.
    extra_inputs: a Tensor [batch_size, total_num_bits // bits_at_once, d]
      of additional inputs, passed as additional LSTM inputs.
    bits_at_once: python integer, how many bits to predict at once.
    temperature: python float, temperature used for sampling during inference.
    dropout: float, the amount of dropout to apply during training (0.1
      default).

  Returns:
    A pair (output, loss).
    Inference (target_bits is None): output is the predicted bit sequence, a
      Tensor of shape [batch_size, total_num_bits] with elements either -1 or
      1, and loss is 0.0.
    Training: output is the logits Tensor of shape
      [batch_size, total_num_bits // bits_at_once, 2**bits_at_once], and loss
      is the mean cross-entropy of those logits against target_bits.
  """
  assert total_num_bits % bits_at_once == 0
  num_steps = total_num_bits // bits_at_once
  # Each step predicts one integer in [0, 2**bits_at_once).
  vocab_size = 2**bits_at_once
  with tf.variable_scope("predict_bits_with_lstm"):
    # Layers and cell state creation.
    lstm_cell = tf.nn.rnn_cell.LSTMCell(state_size)
    discrete_predict = tf.layers.Dense(vocab_size, name="discrete_predict")
    discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")
    batch_size = common_layers.shape_list(prediction_source)[0]
    layer_pred = tf.layers.flatten(prediction_source)
    first_lstm_input = tf.layers.dense(layer_pred, state_size, name="istate")
    c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
    m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
    state = (c_state, m_state)

    # Prediction mode if no targets are given.
    if target_bits is None:
      outputs = []
      lstm_input = first_lstm_input
      for i in range(num_steps):
        if extra_inputs is not None:
          lstm_input = tf.concat([lstm_input, extra_inputs[:, i, :]], axis=1)
        output, state = lstm_cell(lstm_input, state)
        discrete_logits = discrete_predict(output)
        discrete_samples = common_layers.sample_with_temperature(
            discrete_logits, temperature)
        outputs.append(tf.expand_dims(discrete_samples, axis=1))
        # Feed the sample back in. The one-hot depth must match the
        # 2**bits_at_once width discrete_embed sees during training (it was
        # previously hard-coded to 256, i.e. only correct for 8 bits).
        lstm_input = discrete_embed(tf.one_hot(discrete_samples, vocab_size))
      outputs = tf.concat(outputs, axis=1)
      outputs = int_to_bit(outputs, bits_at_once)
      outputs = tf.reshape(outputs, [batch_size, total_num_bits])
      return 2 * outputs - 1, 0.0

    # Training mode, calculating loss.
    # Map {-1, 1} bits to {0, 1} and group them into bits_at_once-bit integers.
    target_bits = tf.reshape(tf.maximum(tf.stop_gradient(target_bits), 0), [
        batch_size, num_steps, bits_at_once])
    target_ints = bit_to_int(target_bits, bits_at_once)
    tf.summary.histogram("target_integers", tf.reshape(target_ints, [-1]))
    target_hot = tf.one_hot(target_ints, vocab_size, axis=-1)
    target_embedded = discrete_embed(target_hot)
    target_embedded = tf.nn.dropout(target_embedded, 1.0 - dropout)
    # Teacher forcing: step i is fed the embedding of target group i - 1
    # (first_lstm_input at step 0).
    teacher_input = tf.concat(
        [tf.expand_dims(first_lstm_input, axis=1), target_embedded], axis=1)
    outputs = []
    for i in range(num_steps):
      lstm_input = teacher_input[:, i, :]
      if extra_inputs is not None:
        lstm_input = tf.concat([lstm_input, extra_inputs[:, i, :]], axis=1)
      output, state = lstm_cell(lstm_input, state)
      outputs.append(tf.expand_dims(output, axis=1))
    outputs = tf.concat(outputs, axis=1)
    outputs = tf.nn.dropout(outputs, 1.0 - dropout)
    d_int_pred = discrete_predict(outputs)
    pred_loss = tf.losses.sparse_softmax_cross_entropy(
        logits=d_int_pred, labels=target_ints)
    pred_loss = tf.reduce_mean(pred_loss)
    return d_int_pred, pred_loss
# New API for discretization bottlenecks:
# * Each method is separate and provides 2 functions:
# * The [method]_bottleneck function returns discretized state.
# * The [method]_unbottleneck function moves from discretized state to dense.
def get_vq_codebook(codebook_size, hidden_size):
  """Get lookup table for VQ bottleneck.

  Creates the variables under a "vq" sub-scope of the current variable scope
  with AUTO_REUSE, so later calls in the same scope reuse them.

  Args:
    codebook_size: Number of codebook entries.
    hidden_size: Dimension of each codebook entry.

  Returns:
    A tuple (means, ema_means, ema_count): the trainable codebook of shape
    [codebook_size, hidden_size], a non-trainable EMA accumulator initialized
    from the codebook, and a non-trainable zero-initialized [codebook_size]
    count vector.
  """
  with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
    codebook = tf.get_variable(
        name="means",
        shape=[codebook_size, hidden_size],
        initializer=tf.uniform_unit_scaling_initializer())
    counts = tf.get_variable(
        name="ema_count",
        shape=[codebook_size],
        initializer=tf.constant_initializer(0),
        trainable=False)

    def _ema_means_initial_value():
      # Copy the current codebook if it is already initialized, otherwise its
      # initial value.
      return tf.cond(
          tf.is_variable_initialized(codebook),
          codebook.read_value,
          lambda: codebook.initial_value)

    with tf.colocate_with(codebook):
      codebook_ema = tf.get_variable(
          name="ema_means",
          initializer=_ema_means_initial_value(),
          trainable=False)
  return codebook, codebook_ema, counts
def vq_nearest_neighbor(x, means,
                        soft_em=False, num_samples=10, temperature=None):
  """Find the nearest element in means to elements in x.

  Args:
    x: 2-D Tensor [N, hidden_size] of vectors to quantize.
    means: Codebook Tensor [codebook_size, hidden_size].
    soft_em: If True, draw num_samples codes from softmax(-distance) and
      return the average of their one-hot encodings.
    num_samples: Number of samples drawn when soft_em is True.
    temperature: If not None (and soft_em is False), draw one code from
      softmax(-distance / temperature) instead of taking the nearest one.

  Returns:
    A tuple (x_means_hot, e_loss, dist):
      x_means_hot: [N, codebook_size] (possibly averaged) one-hot assignments.
      e_loss: Scalar commitment loss, the mean squared difference between x
        and its gradient-stopped quantized reconstruction.
      dist: [N, codebook_size] squared Euclidean distances.
  """
  bottleneck_size = common_layers.shape_list(means)[0]
  # Pairwise squared distances via ||x||^2 + ||m||^2 - 2 <x, m>.
  x_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  cross = tf.matmul(x, means, transpose_b=True)
  dist = x_sq + tf.transpose(means_sq) - 2 * cross

  if soft_em:
    sampled_idx = tf.multinomial(-dist, num_samples=num_samples)
    sampled_hot = tf.one_hot(
        sampled_idx, depth=common_layers.shape_list(means)[0])
    x_means_hot = tf.reduce_mean(sampled_hot, axis=1)
  else:
    if temperature is not None:
      nearest_idx = tf.squeeze(
          tf.multinomial(-dist / temperature, 1), axis=-1)
    else:
      nearest_idx = tf.argmax(-dist, axis=-1)
    if (common_layers.should_generate_summaries() and
        not common_layers.is_xla_compiled()):
      tf.summary.histogram("means_idx", tf.reshape(nearest_idx, [-1]))
    x_means_hot = tf.one_hot(nearest_idx, bottleneck_size)

  flat_hot = tf.reshape(x_means_hot, [-1, bottleneck_size])
  quantized = tf.matmul(flat_hot, means)
  e_loss = tf.reduce_mean(tf.squared_difference(x, tf.stop_gradient(quantized)))
  return x_means_hot, e_loss, dist
def vq_discrete_bottleneck(x,
                           bottleneck_bits,
                           beta=0.25,
                           decay=0.999,
                           epsilon=1e-5,
                           soft_em=False,
                           num_samples=10):
  """Simple vector quantized discrete bottleneck.

  Thin wrapper around vq_body with a codebook of 2**bottleneck_bits entries
  that discards the distances output.

  Args:
    x: Tensor [..., hidden_size] to quantize.
    bottleneck_bits: log2 of the codebook size.
    beta: Weight of the commitment loss.
    decay: Decay of the codebook exponential moving averages.
    epsilon: Smoothing constant for the EMA counts.
    soft_em: Whether to use soft (sampled) assignments.
    num_samples: Number of samples when soft_em is True.

  Returns:
    A pair (x_means_hot, e_loss): code assignments of shape
    x.shape[:-1] + [2**bottleneck_bits] and the loss returned by vq_body.
  """
  num_codes = 2**bottleneck_bits
  codes, loss, _distances = vq_body(
      x,
      num_codes,
      beta=beta,
      decay=decay,
      epsilon=epsilon,
      soft_em=soft_em,
      num_samples=num_samples)
  return codes, loss
def vq_body(x,
            codebook_size,
            beta=0.25,
            decay=0.999,
            epsilon=1e-5,
            soft_em=False,
            num_samples=10,
            temperature=None,
            do_update=True):
  """Discretize each x into one of codebook_size codes.

  Args:
    x: Tensor [..., hidden_size] to quantize.
    codebook_size: Number of codebook entries.
    beta: Weight of the commitment loss.
    decay: Decay of the codebook exponential moving averages.
    epsilon: Smoothing constant for the EMA counts.
    soft_em: Whether to use soft (sampled) assignments; see
      vq_nearest_neighbor.
    num_samples: Number of samples when soft_em is True.
    temperature: Optional sampling temperature; see vq_nearest_neighbor.
    do_update: Whether to update the codebook with EMA statistics. Either a
      Python value (resolved while building the graph) or a boolean scalar
      Tensor/Variable (resolved at run time via tf.cond).

  Returns:
    A tuple (d, loss, distances):
      d: Code assignments of shape x.shape[:-1] + [codebook_size].
      loss: beta times the commitment loss; when an update is performed,
        evaluating it also triggers the codebook update.
      distances: [N, codebook_size] squared distances, where N is the number
        of vectors in x.
  """
  x_shape = common_layers.shape_list(x)
  hidden_size = x_shape[-1]
  means, ema_means, ema_count = get_vq_codebook(codebook_size, hidden_size)
  x = tf.reshape(x, [-1, hidden_size])
  x_means_hot, e_loss, distances = vq_nearest_neighbor(
      x, means, soft_em=soft_em, num_samples=num_samples,
      temperature=temperature)

  def loss_with_update():
    """Update the ema variables and return loss triggering the update."""
    updated_ema_count = moving_averages.assign_moving_average(
        ema_count,
        tf.reduce_sum(tf.reshape(x_means_hot, shape=[-1, codebook_size]),
                      axis=0),
        decay,
        zero_debias=False)

    # Sum of the vectors assigned to each code: [codebook_size, hidden_size].
    dw = tf.matmul(x_means_hot, x, transpose_a=True)
    updated_ema_means = tf.identity(
        moving_averages.assign_moving_average(
            ema_means, dw, decay, zero_debias=False))
    # Laplace-smooth the counts so codes that were not selected still get a
    # small positive count before dividing.
    n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
    updated_ema_count = (
        (updated_ema_count + epsilon) / (n + codebook_size * epsilon) * n)
    updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)
    with tf.control_dependencies([e_loss]):
      update_means = means.assign(updated_ema_means)
      with tf.control_dependencies([update_means]):
        return beta * e_loss

  # Loss, also do update if requested. A Tensor flag cannot be used as a
  # Python bool in graph mode, and TF1 tf.cond rejects Python bools, so the two
  # cases are dispatched separately.
  if isinstance(do_update, (tf.Tensor, tf.Variable)):
    loss = tf.cond(do_update, loss_with_update, lambda: beta * e_loss)
  elif do_update:
    loss = loss_with_update()
  else:
    loss = beta * e_loss

  d = tf.reshape(x_means_hot, x_shape[:-1] + [codebook_size])
  return d, loss, distances
def vq_loss(x,
            targets,
            codebook_size,
            beta=0.25,
            decay=0.999,
            epsilon=1e-5,
            soft_em=False,
            num_samples=10,
            temperature=None,
            do_update=True):
  """Compute the loss of large vocab tensors using a VQAE codebook.

  Args:
    x: Tensor [..., hidden_size] of inputs to be quantized to nearest code.
    targets: Integer Tensor of target code indices; its total size must match
      the number of vectors in x.
    codebook_size: Size of quantization codebook.
    beta: Weight of the commitment loss (passed to vq_body).
    decay: Decay of the codebook exponential moving averages.
    epsilon: Smoothing constant for the EMA counts.
    soft_em: boolean, whether to apply a soft sampling procedure.
    num_samples: if soft_em, number of samples to take.
    temperature: temperature if we want to sample nearest neighbors or None.
    do_update: whether to update the means; True by default, can be a Tensor.

  Returns:
    A tuple (discrete_x, x_means, target_means, code_loss, targets_loss):
      discrete_x: Tensor of shape x.shape[:-1] + [codebook_size] indicating
        which codebook element is closest to each vector of x.
      x_means: Tensor of shape [N, hidden_size], N = number of vectors in x
        (flattened, unlike the other outputs). On the forward pass it is the
        selected codebook element; on the backward pass the gradient passes
        straight through to x.
      target_means: the codebook elements corresponding to the targets, shape
        targets.shape + [hidden_size].
      code_loss: loss driving x closer to its nearest codebook element.
      targets_loss: cross-entropy loss, with negative squared distances as
        logits, driving x closer to the code corresponding to its target.
  """
  x_shape = common_layers.shape_list(x)
  target_shape = common_layers.shape_list(targets)
  hidden_size = x_shape[-1]
  means, _, _ = get_vq_codebook(codebook_size, hidden_size)
  flat_x = tf.reshape(x, [-1, hidden_size])
  flat_targets = tf.reshape(targets, [-1])
  target_means = tf.matmul(tf.one_hot(flat_targets, codebook_size), means)

  flat_codes, code_loss, distances = vq_body(
      flat_x,
      codebook_size,
      beta=beta,
      decay=decay,
      epsilon=epsilon,
      soft_em=soft_em,
      num_samples=num_samples,
      temperature=temperature,
      do_update=do_update)

  # Nearer codes get larger logits.
  targets_loss = tf.reduce_mean(
      tf.losses.sparse_softmax_cross_entropy(
          logits=-distances, labels=flat_targets))

  # Straight-through estimator.
  quantized = tf.matmul(flat_codes, means)
  x_means = flat_x + tf.stop_gradient(quantized - flat_x)

  discrete_x = tf.reshape(flat_codes, x_shape[:-1] + [codebook_size])
  target_means = tf.reshape(target_means, target_shape + [hidden_size])
  return discrete_x, x_means, target_means, code_loss, targets_loss
def vq_discrete_unbottleneck(x, hidden_size):
  """Simple undiscretization from vector quantized representation.

  Args:
    x: Tensor of code assignments, shape [..., codebook_size]; it is cast to
      float, so integer one-hot codes are accepted.
    hidden_size: Dimension of the codebook entries.

  Returns:
    Tensor of shape [..., hidden_size]: the assignments multiplied into the
    shared VQ codebook.
  """
  original_shape = common_layers.shape_list(x)
  codes = tf.to_float(x)
  num_codes = common_layers.shape_list(codes)[-1]
  codebook, _, _ = get_vq_codebook(num_codes, hidden_size)
  flat_codes = tf.reshape(codes, [-1, original_shape[-1]])
  projected = tf.matmul(flat_codes, codebook)
  return tf.reshape(projected, original_shape[:-1] + [hidden_size])
def gumbel_softmax_nearest_neighbor_dvq(x,
                                        means,
                                        block_v_size,
                                        hard=False,
                                        temperature_init=1.2,
                                        num_samples=1,
                                        temperature_warmup_steps=150000,
                                        summary=True,
                                        num_flows=0,
                                        approximate_gs_entropy=False,
                                        sum_over_latents=False):
  """Sample from Gumbel-Softmax and compute neighbors and losses.

  Args:
    x: A `float`-like `Tensor` of shape [batch_size, latent_dim, num_blocks,
      block_dim] containing the latent vectors to be compared to the codebook.
    means: Embedding table of shape [num_blocks, block_v_size, block_dim].
    block_v_size: Number of discrete codes per block.
    hard: Determines whether we take hard or soft Gumbel-Softmax samples
      (Default: False).
    temperature_init: Initial temperature used for Gumbel-Softmax samples,
      after which it decays to 0 (Default: 1.2).
    num_samples: Number of samples drawn for each latent (Default: 1).
    temperature_warmup_steps: Number of steps it takes to decay temperature to 0
      (Default: 150000).
    summary: When `True`, we save a histogram summary of the per-latent
      negative entropy (Default: True).
    num_flows: Number of inverse autoregressive flows with Gumbel-Softmax
      samples.
    approximate_gs_entropy: When `True`, we approximate Gumbel-Softmax
      density as categorical when calculating sample entropy (Default: False).
    sum_over_latents: Whether to sum over non-batch dimensions when calculating
      negative entropy loss.

  Returns:
    x_means_assignments: A `float`-like `Tensor` containing the codebook
      assignments, averaged over samples, with shape [batch_size * latent_dim,
      num_blocks, block_v_size].
    neg_q_entropy: The negative entropy of the variational distribution,
      averaged over samples (plus the flows' log-det terms if num_flows > 0).
  """
  batch_size, latent_dim, num_blocks, block_dim = common_layers.shape_list(x)

  # Combine latent_dim and batch_size for computing distances.
  x = tf.reshape(x, [-1, num_blocks, block_dim])

  # Compute distances using (x - means)**2 = x**2 + means**2 - 2*x*means.
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  means_norm_sq = tf.transpose(means_norm_sq, perm=[2, 0, 1])
  # Batched over blocks: [num_blocks, N, block_dim] x [num_blocks, block_dim,
  # block_v_size], then transposed back to [N, num_blocks, block_v_size].
  scalar_prod = tf.matmul(
      tf.transpose(x, perm=[1, 0, 2]), tf.transpose(means, perm=[0, 2, 1]))
  scalar_prod = tf.transpose(scalar_prod, perm=[1, 0, 2])
  dist = x_norm_sq + means_norm_sq - 2 * scalar_prod

  # IAF requires latents to have their own dimension, so reshape dist from
  # [batch_size * latent_dim, num_blocks, block_v_size] to
  # [batch_size * num_blocks, latent_dim, block_v_size].
  dist = tf.reshape(dist, [batch_size, latent_dim, num_blocks, -1])
  dist = tf.reshape(
      tf.transpose(dist, perm=[0, 2, 1, 3]), [-1, latent_dim, block_v_size])
  log_class_probs = tf.nn.log_softmax(-dist)

  sample_shape = [num_samples] + common_layers.shape_list(dist)
  gumbel_samples = gumbel_sample(sample_shape)

  # Temperature decays linearly.
  temperature = temperature_init - common_layers.inverse_lin_decay(
      temperature_warmup_steps)

  # 10% of the time keep reasonably high temperature to keep learning.
  temperature = tf.cond(
      tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
      lambda: tf.random_uniform([], minval=0.5, maxval=1.0))

  # Shape [num_samples, batch_size * num_blocks, latent_dim, block_v_size].
  gumbel_softmax_samples = tf.nn.softmax(
      (tf.expand_dims(log_class_probs, 0) + gumbel_samples) / temperature)
  # Clip away from 0 and 1 before evaluating log-densities of the samples.
  q_samples = tf.clip_by_value(gumbel_softmax_samples, 1e-6, 1 - 1e-6)

  if approximate_gs_entropy:
    q_dist = tfp.distributions.Multinomial(total_count=1.0, logits=-dist)
  else:
    q_dist = tfp.distributions.RelaxedOneHotCategorical(
        temperature, logits=-dist)

  # Take mean over samples to approximate entropy.
  neg_q_entropy = tf.reduce_mean(q_dist.log_prob(q_samples), 0)
  if summary:
    tf.summary.histogram("neg_q_entropy", tf.reshape(neg_q_entropy, [-1]))
  if sum_over_latents:
    neg_q_entropy = tf.reshape(neg_q_entropy,
                               [batch_size, num_blocks, latent_dim])
    neg_q_entropy = tf.reduce_sum(neg_q_entropy, [1, 2])
  neg_q_entropy = tf.reduce_mean(neg_q_entropy)

  if num_flows > 0:
    hparams = iaf_hparams(hidden_size=512, filter_size=4096)
    # Fold the sample axis into the leading axis:
    # [num_samples * batch_size * num_blocks, latent_dim, block_v_size].
    q_samples = tf.reshape(q_samples, [-1, latent_dim, block_v_size])
    for flow in range(num_flows):
      # Shift right by one latent so each position only conditions on the
      # preceding latents (autoregressive).
      shifted_samples = tf.pad(q_samples, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]

      # Project samples from [..., latent_dim, block_v_size] to
      # [..., latent_dim, hidden_size].
      shifted_samples = common_layers.dense(shifted_samples,
                                            hparams.hidden_size)
      # TODO(vafa): Include masking as a flag.
      mask = True
      if mask:
        attention_type = cia.AttentionType.LOCAL_1D
      else:
        attention_type = cia.AttentionType.GLOBAL
      ffn_output = cia.transformer_decoder_layers(
          inputs=shifted_samples,
          encoder_output=None,
          num_layers=6,
          hparams=hparams,
          attention_type=attention_type,
          name="transformer_" + str(flow))

      # Project samples back to [..., latent_dim, block_v_size].
      ffn_output = common_layers.dense(ffn_output, block_v_size)
      log_pi = tf.nn.log_softmax(ffn_output)

      # Flow 1: Adding log_pi to q_samples and dividing by the temperature.
      # Note that we drop the last dimension of q_samples for centered-softmax,
      # which we can do without recalculating probabilities because the last
      # dimension of log_pi and q_samples are deterministic given the others.
      # Flow 2: Centered-softmax.
      chained_bijectors = tfp.bijectors.Chain([
          tfp.bijectors.SoftmaxCentered(),
          tfp.bijectors.Affine(
              shift=log_pi[:, :, :-1],
              scale_identity_multiplier=1. / temperature)
      ])
      q_samples = chained_bijectors.forward(q_samples[:, :, :-1])
      log_det = chained_bijectors.inverse_log_det_jacobian(
          q_samples, event_ndims=1)
      log_det = tf.reshape(log_det,
                           [num_samples, batch_size, num_blocks, latent_dim])
      if sum_over_latents:
        log_det = tf.reduce_sum(log_det, axis=[2, 3])
      neg_q_entropy += tf.reduce_mean(log_det)

    q_samples = tf.reshape(
        q_samples,
        [num_samples, batch_size * num_blocks, latent_dim, block_v_size])

  if hard:
    x_means_idx = tf.argmax(q_samples, -1)

    # Take average of one-hot vectors over samples.
    x_means_hot = tf.reduce_mean(tf.one_hot(x_means_idx, block_v_size), 0)
    # Straight-through: forward value is the averaged one-hot, gradient flows
    # through the averaged (possibly flow-transformed) soft samples.
    x_means_assignments = (
        tf.reduce_mean(q_samples, 0) +
        tf.stop_gradient(x_means_hot - tf.reduce_mean(q_samples, 0)))
  else:
    # NOTE(review): the soft path averages the unclipped, pre-flow samples, so
    # num_flows only affects neg_q_entropy here; confirm this is intended.
    x_means_assignments = tf.reduce_mean(gumbel_softmax_samples, 0)

  # Reshape assignments to [batch_size * latent_dim, num_blocks,
  # block_v_size]. We have to transpose between reshapes to make sure the
  # dimensions have the correct interpretation.
  x_means_assignments = tf.reshape(
      x_means_assignments, [batch_size, num_blocks, latent_dim, block_v_size])
  x_means_assignments = tf.transpose(x_means_assignments, [0, 2, 1, 3])
  x_means_assignments = tf.reshape(
      x_means_assignments, [batch_size * latent_dim, num_blocks, block_v_size])

  return x_means_assignments, neg_q_entropy
def gumbel_softmax_discrete_bottleneck(x,
                                       bottleneck_bits,
                                       beta=0.25,
                                       decay=0.999,
                                       epsilon=1e-5,
                                       temperature_warmup_steps=150000,
                                       hard=False,
                                       summary=True):
  """VQ-VAE using Gumbel-Softmax.

  Different from `gumbel_softmax()` function as
  this function calculates the KL by using the discrete entropy
  instead of taking the argmax, and it also uses an exponential moving average
  to update the codebook while the `gumbel_softmax()` function includes no
  codebook update.

  Args:
    x: A `float`-like `Tensor` containing the latent vectors to be compared to
      the codebook, whose squared difference is used as the Gumbel-Softmax
      logits. Its last dimension is the hidden size; all leading dimensions
      are flattened internally and restored on output.
    bottleneck_bits: An `int` that sets the size of the bottleneck in `log_2`.
    beta: Beta factor for commitment loss (Default: 0.25).
    decay: Decay factor for exponential moving average (Default: 0.999).
    epsilon: Small value to avoid dividing by zero in EMA update
      (Default: 1e-5).
    temperature_warmup_steps: Annealing horizon (Default: 150000). The
      temperature is `1.2 - common_layers.inverse_lin_decay(steps)` and the
      Gumbel noise is scaled by
      `0.5 * common_layers.inverse_exp_decay(steps // 5)`. 10% of the time a
      temperature drawn uniformly from [0.5, 1.0) is used instead.
      NOTE(review): from this formula the temperature bottoms out at 1.2 minus
      the final value of `inverse_lin_decay` (presumably 0.2, not 0) —
      confirm.
    hard: When `True`, we use hard Gumbel-Softmax samples and force
      discrete latents by taking the argmax. When `False`, we use soft samples,
      which we treat as codebook weights (Default: False).
    summary: When `True`, we save histogram summaries of the KL term (Default:
      True).

  Returns:
    x_means_assignments: A `float`-like `Tensor` containing the codebook
      assignments, shaped like `x` with its last dimension replaced by the
      bottleneck size. When `hard == True`, this is one-hot, containing the
      arg-max of the Gumbel-Softmax samples (and we use the straightthrough
      gradient). Otherwise, it contains the Gumbel-Softmax samples exactly,
      which are values from the `(K-1)`-simplex where `K` is the bottleneck
      size.
    loss: The loss, which is the sum of the KL between the Gumbel-Softmax and
      the uniform prior and the commitment loss multiplied by the beta factor.
      We approximate the KL by using the entropy of a categorical distribution
      instead of the Gumbel Softmax. Evaluating it also triggers the EMA
      codebook update.
  """
  bottleneck_size = 2**bottleneck_bits
  x_shape = common_layers.shape_list(x)
  hidden_size = x_shape[-1]
  means, ema_means, ema_count = get_vq_codebook(bottleneck_size, hidden_size)
  x = tf.reshape(x, [-1, hidden_size])

  # From here on use the number of rows of the codebook actually returned.
  bottleneck_size = common_layers.shape_list(means)[0]
  # Squared Euclidean distances ||x||^2 + ||m||^2 - 2 x.m between every row of
  # x and every codebook entry: shape [num_vectors, bottleneck_size].
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  scalar_prod = tf.matmul(x, means, transpose_b=True)
  dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod

  # NOTE(review): the distances (not their negation) are used directly as
  # logits, so codebook entries farther from x get higher probability; a
  # nearest-neighbour assignment would use -dist. This matches the docstring
  # but should be confirmed as intended.
  class_probs = tf.nn.softmax(dist)
  log_class_probs = tf.nn.log_softmax(dist)
  gumbel_samples = gumbel_sample(common_layers.shape_list(dist))
  steps = temperature_warmup_steps
  # Scale the Gumbel noise and anneal the temperature over `steps`.
  gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
  temperature = 1.2 - common_layers.inverse_lin_decay(steps)

  # 10% of the time keep reasonably high temperature to keep learning.
  temperature = tf.cond(
      tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
      lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
  gumbel_softmax_samples = tf.nn.softmax(
      (log_class_probs + gumbel_samples) / temperature)

  # Calculate KL between q and a uniform prior.
  kl = tf.reduce_sum(
      class_probs * (log_class_probs - tf.log(1.0 / bottleneck_size)), -1)
  if summary:
    tf.summary.histogram("KL", tf.reshape(kl, [-1]))

  # Straight-through gradient estimation when we're using hard assignments.
  if hard:
    x_means_idx = tf.reshape(tf.argmax(gumbel_softmax_samples, axis=-1), [-1])
    x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
    x_means_assignments = gumbel_softmax_samples + tf.stop_gradient(
        x_means_hot - gumbel_softmax_samples)
  else:
    x_means_assignments = gumbel_softmax_samples
  # Reconstruct each x as the assignment-weighted combination of codebook rows.
  x_means_assignments_flat = tf.reshape(x_means_assignments,
                                        [-1, bottleneck_size])
  x_means = tf.matmul(x_means_assignments_flat, means)
  # Commitment loss pulls x toward its (gradient-stopped) reconstruction.
  commitment_loss = tf.reduce_mean(
      tf.squared_difference(x, tf.stop_gradient(x_means)))

  # Update the ema variables.
  # ema_count tracks the total assignment mass per codebook entry.
  updated_ema_count = moving_averages.assign_moving_average(
      ema_count,
      tf.reduce_sum(
          tf.reshape(x_means_assignments, shape=[-1, bottleneck_size]), axis=0),
      decay,
      zero_debias=False)

  # ema_means tracks the assignment-weighted sum of inputs per codebook entry.
  dw = tf.matmul(x_means_assignments, x, transpose_a=True)
  updated_ema_means = tf.identity(
      moving_averages.assign_moving_average(
          ema_means, dw, decay, zero_debias=False))
  # Epsilon smoothing of the counts. It makes non-negative counts strictly
  # positive while keeping their total equal to n.
  n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
  updated_ema_count = (
      (updated_ema_count + epsilon) / (n + bottleneck_size * epsilon) * n)
  # New codebook = accumulated sums / smoothed counts.
  updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)
  # Compute the commitment loss before the codebook is overwritten, and make
  # the returned loss depend on the codebook update so that it is executed.
  with tf.control_dependencies([commitment_loss]):
    update_means = means.assign(updated_ema_means)
    with tf.control_dependencies([update_means]):
      loss = beta * commitment_loss

  # Add KL loss.
  loss += tf.reduce_mean(kl)

  x_means_assignments = tf.reshape(x_means_assignments,
                                   x_shape[:-1] + [bottleneck_size])
  return x_means_assignments, loss
def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
                             discretize_warmup_steps, mode):
  """Simple discretization through tanh, flip bottleneck_noise many bits.

  Args:
    x: Input `Tensor`; its last dimension is projected to `bottleneck_bits`
      by a dense layer named "tanh_discrete_bottleneck".
    bottleneck_bits: Number of bits (output depth of the dense projection).
    bottleneck_noise: In TRAIN mode a bit's sign is flipped whenever a
      uniform sample does not exceed this value.
    discretize_warmup_steps: Step horizon passed to `common_layers.mix` when
      blending the discrete bits with the continuous tanh activations.
    mode: A `tf_estimator.ModeKeys` value. Pre-tanh noise and bit flips are
      only applied in TRAIN.

  Returns:
    d: Bits that are {-1, 1} in the forward pass before mixing, with
      straight-through gradients to the tanh activations. In TRAIN they are
      randomly flipped, and in all modes they are passed through
      `common_layers.mix` with the tanh values.
    d0: Gradient-stopped {-1, 1} bits computed from the projection before any
      noise or tanh is applied.
  """
  x = tf.layers.dense(x, bottleneck_bits, name="tanh_discrete_bottleneck")
  # +1 where the projection is positive, -1 otherwise; no gradient.
  d0 = tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x))) - 1.0
  if mode == tf_estimator.ModeKeys.TRAIN:
    # Gaussian noise before the tanh during training.
    x += tf.truncated_normal(
        common_layers.shape_list(x), mean=0.0, stddev=0.2)
  x = tf.tanh(x)
  # Straight-through estimator: the forward value is the sign bit (+1/-1), and
  # the gradient flows to the tanh output as if this were the identity.
  d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
  if mode == tf_estimator.ModeKeys.TRAIN:
    # noise is +1 where the uniform sample exceeds bottleneck_noise and -1
    # otherwise; multiplying flips the corresponding bits.
    noise = tf.random_uniform(common_layers.shape_list(x))
    noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
    d *= noise
  # Blend discrete bits with the continuous tanh values; the schedule is
  # defined by common_layers.mix (presumably annealed over the warmup steps).
  d = common_layers.mix(d, x, discretize_warmup_steps,
                        mode == tf_estimator.ModeKeys.TRAIN)
  return d, d0
def tanh_discrete_unbottleneck(x, hidden_size):
  """Maps tanh-discretized bits back to a dense `hidden_size` representation.

  Args:
    x: Bits produced by `tanh_discrete_bottleneck`.
    hidden_size: Output depth of the dense projection.

  Returns:
    The output of a single dense layer named "tanh_discrete_unbottleneck".
  """
  return tf.layers.dense(x, hidden_size, name="tanh_discrete_unbottleneck")
def isemhash_bottleneck(x,
                        bottleneck_bits,
                        bottleneck_noise,
                        discretize_warmup_steps,
                        mode,
                        isemhash_noise_dev=0.5,
                        isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck.

  Args:
    x: Input `Tensor`; its last dimension is projected to `bottleneck_bits`.
    bottleneck_bits: Number of bits (output depth of the dense projection).
    bottleneck_noise: In TRAIN mode a bit's sign is flipped whenever a
      uniform sample does not exceed this value.
    discretize_warmup_steps: Step horizon passed to `common_layers.mix`.
    mode: A `tf_estimator.ModeKeys` value. Noise, bit flips and mixing are
      only applied in TRAIN.
    isemhash_noise_dev: Stddev of the truncated-normal noise added before the
      saturating sigmoid in TRAIN mode; disabled when <= 0.
    isemhash_mix_prob: `max_prob` passed to `common_layers.mix`.

  Returns:
    A pair `(d, 0.0)`. `d` holds bits in [-1, 1] with straight-through
    gradients to the saturating-sigmoid activations. The second element is a
    constant 0.0.
  """
  with tf.variable_scope("isemhash_bottleneck"):
    x = tf.layers.dense(x, bottleneck_bits, name="dense")
    y = common_layers.saturating_sigmoid(x)
    if isemhash_noise_dev > 0 and mode == tf_estimator.ModeKeys.TRAIN:
      noise = tf.truncated_normal(
          common_layers.shape_list(x), mean=0.0, stddev=isemhash_noise_dev)
      y = common_layers.saturating_sigmoid(x + noise)
    # Straight-through: the forward value is the hard 0/1 threshold of y at
    # 0.5, and the gradient flows through y.
    d = tf.to_float(tf.less(0.5, y)) + y - tf.stop_gradient(y)
    d = 2.0 * d - 1.0  # Move from [0, 1] to [-1, 1].
    if mode == tf_estimator.ModeKeys.TRAIN:  # Flip some bits.
      # noise is +1 where the uniform sample exceeds bottleneck_noise, else -1.
      noise = tf.random_uniform(common_layers.shape_list(x))
      noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
      d *= noise
      # Blend with the continuous values 2y - 1 (TRAIN only; outside TRAIN the
      # hard bits are returned as-is).
      d = common_layers.mix(
          d,
          2.0 * y - 1.0,
          discretize_warmup_steps,
          mode == tf_estimator.ModeKeys.TRAIN,
          max_prob=isemhash_mix_prob)
    return d, 0.0
def isemhash_unbottleneck(x, hidden_size, isemhash_filter_size_multiplier=1.0):
  """Improved semantic hashing un-bottleneck.

  Maps bits in [-1, 1] (as produced by `isemhash_bottleneck`) back to a dense
  representation. Each bit and its complement are projected separately
  before two further dense layers.

  Args:
    x: `Tensor` of bits with values in [-1, 1].
    hidden_size: Depth of the returned representation.
    isemhash_filter_size_multiplier: Width of the hidden dense layers relative
      to `hidden_size`.

  Returns:
    `Tensor` with the leading dimensions of `x` and last dimension
    `hidden_size`.
  """
  filter_size = int(hidden_size * isemhash_filter_size_multiplier)
  # Move from [-1, 1] to [0, 1], undoing the `2.0 * d - 1.0` applied by
  # isemhash_bottleneck. Note that `0.5 * (x - 1.0)` would instead land in
  # [-1, 0], and then `1.0 - x` below would not be the bit complement.
  x = 0.5 * (x + 1.0)
  with tf.variable_scope("isemhash_unbottleneck"):
    # Separate projections of each bit (x) and of its complement (1 - x).
    h1a = tf.layers.dense(x, filter_size, name="hidden1a")
    h1b = tf.layers.dense(1.0 - x, filter_size, name="hidden1b")
    h2 = tf.layers.dense(tf.nn.relu(h1a + h1b), filter_size, name="hidden2")
    return tf.layers.dense(tf.nn.relu(h2), hidden_size, name="final")
def parametrized_bottleneck(x, hparams):
  """Builds the bottleneck selected by `hparams.bottleneck_kind`.

  Args:
    x: Input `Tensor` to discretize.
    hparams: HParams. `bottleneck_kind` must be one of "tanh_discrete",
      "isemhash", "vq", "em" or "gumbel_softmax"; which other fields are read
      depends on that choice.

  Returns:
    The pair returned by the selected bottleneck. For "tanh_discrete" the
    second element is the constant 0.0.

  Raises:
    ValueError: If `hparams.bottleneck_kind` is not one of the kinds above.
  """
  kind = hparams.bottleneck_kind

  def _tanh_discrete():
    # Keep only the (noisy, mixed) bits; drop the second output.
    bits, _ = tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise * 0.5,
        hparams.discretize_warmup_steps, hparams.mode)
    return bits, 0.0

  def _isemhash():
    return isemhash_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise * 0.5,
        hparams.discretize_warmup_steps, hparams.mode,
        hparams.isemhash_noise_dev, hparams.isemhash_mix_prob)

  def _vq():
    return vq_discrete_bottleneck(x, hparams.bottleneck_bits, hparams.vq_beta,
                                  hparams.vq_decay, hparams.vq_epsilon)

  def _em():
    # Same VQ bottleneck, with soft EM assignments.
    return vq_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.vq_beta, hparams.vq_decay,
        hparams.vq_epsilon, soft_em=True, num_samples=hparams.vq_num_samples)

  def _gumbel_softmax():
    return gumbel_softmax_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.vq_beta, hparams.vq_decay,
        hparams.vq_epsilon, hparams.temperature_warmup_steps,
        hard=False, summary=True)

  # Checked in order with `==`, exactly like a chain of if-statements.
  builders = (
      ("tanh_discrete", _tanh_discrete),
      ("isemhash", _isemhash),
      ("vq", _vq),
      ("em", _em),
      ("gumbel_softmax", _gumbel_softmax),
  )
  for candidate, build in builders:
    if kind == candidate:
      return build()
  raise ValueError("Unsupported hparams.bottleneck_kind %s" % kind)
def parametrized_unbottleneck(x, hidden_size, hparams):
  """Builds the un-bottleneck matching `hparams.bottleneck_kind`.

  Args:
    x: Output of the corresponding bottleneck.
    hidden_size: Depth requested from the un-bottleneck.
    hparams: HParams; `bottleneck_kind` selects the un-bottleneck.

  Returns:
    The un-bottleneck's output `Tensor`.

  Raises:
    ValueError: If `hparams.bottleneck_kind` is not supported.
  """
  kind = hparams.bottleneck_kind
  if kind == "tanh_discrete":
    decoded = tanh_discrete_unbottleneck(x, hidden_size)
  elif kind == "isemhash":
    decoded = isemhash_unbottleneck(x, hidden_size,
                                    hparams.isemhash_filter_size_multiplier)
  elif kind in ("vq", "em", "gumbel_softmax"):
    # All codebook-based bottlenecks share one un-bottleneck.
    decoded = vq_discrete_unbottleneck(x, hidden_size)
  else:
    raise ValueError("Unsupported hparams.bottleneck_kind %s" % kind)
  return decoded
def iaf_hparams(hidden_size=512, filter_size=4096):
  """Create hyperparameters for inverse autoregressive flows.

  Starts from `common_hparams.basic_params1()`, adds attention and
  feed-forward fields, and overrides the pre/post-processing settings.

  Args:
    hidden_size: Width of attention layers and neural network output layer.
    filter_size: Hidden layer width for neural network.

  Returns:
    hparams: Hyperparameters with basic presets for inverse autoregressive
      flows.
  """
  hparams = common_hparams.basic_params1()

  # Attention hyperparameters.
  hparams.hidden_size = hidden_size
  attention_fields = (
      ("attention_key_channels", None),
      ("attention_value_channels", None),
      ("num_heads", 4),
      ("attention_dropout", 0.1),
      ("shared_rel", False),
      ("block_width", 1),
      ("block_length", 1),
      ("q_filter_width", 1),
      ("kv_filter_width", 1),
  )
  for field, value in attention_fields:
    hparams.add_hparam(field, value)

  # Preprocessing and postprocessing hyperparameters (overrides).
  prepost_overrides = (
      ("layer_preprocess_sequence", "n"),
      ("layer_prepostprocess_dropout", 0.1),
      ("norm_type", "layer"),
      ("norm_epsilon", 1e-06),
      ("layer_prepostprocess_dropout_broadcast_dims", ""),
      ("layer_postprocess_sequence", "da"),
  )
  for field, value in prepost_overrides:
    setattr(hparams, field, value)

  # Feedforward neural network hyperparameters.
  feedforward_fields = (
      ("filter_size", filter_size),
      ("ffn_layer", "conv_hidden_relu"),
      ("relu_dropout", 0.1),
  )
  for field, value in feedforward_fields:
    hparams.add_hparam(field, value)
  return hparams
================================================
FILE: tensor2tensor/layers/discretization_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for discretization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.layers import discretization
from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf
# Enable eager execution at import time. Test methods below without a
# `test_utils` mode decorator therefore run eagerly.
tf.enable_eager_execution()
class DiscretizationTest(tf.test.TestCase):
  """Tests for discretization layers."""

  def setUp(self):
    """Resets TestCase per-test state, then fixes TF and NumPy seeds."""
    # Let tf.test.TestCase do its own per-test setup (e.g. clearing cached
    # sessions) before this test's fixed seeds are applied.
    super(DiscretizationTest, self).setUp()
    tf.set_random_seed(1234)
    np.random.seed(123)

  def _create_codebook(self, hidden_size, z_size):
    """Creates zero-initialized codebook variables in the current scope.

    Variables are created in the order "means", "ema_count", "ema_means";
    "ema_means" is colocated with `means` and initialized from `means[0]`.

    Args:
      hidden_size: Depth of each codebook entry.
      z_size: log2 of the number of codebook entries.

    Returns:
      A `(means, ema_means, ema_count)` tuple in the form these tests pass to
      `discretization.discrete_bottleneck`: `means` has shape
      [1, 1, 2**z_size, hidden_size]; `ema_means` and `ema_count` are
      single-element lists of non-trainable variables.
    """
    means = tf.get_variable("means",
                            shape=[1, 1, 2**z_size, hidden_size],
                            initializer=tf.constant_initializer(0.),
                            dtype=tf.float32)
    ema_count_i = tf.get_variable(
        "ema_count",
        [1, 2**z_size],
        initializer=tf.constant_initializer(0),
        trainable=False)
    with tf.colocate_with(means):
      ema_means_i = tf.get_variable("ema_means",
                                    initializer=means.initialized_value()[0],
                                    trainable=False)
    return means, [ema_means_i], [ema_count_i]

  @test_utils.run_in_graph_and_eager_modes()
  def testBitToIntZeros(self):
    """All-zero bits decode to integer 0."""
    x_bit = tf.zeros(shape=[1, 10], dtype=tf.float32)
    x_int = tf.zeros(shape=[1], dtype=tf.int32)
    diff = discretization.bit_to_int(x_bit, num_bits=10) - x_int
    d = self.evaluate(diff)
    self.assertAllEqual(d, np.zeros_like(d))

  @test_utils.run_in_graph_and_eager_modes()
  def testBitToIntOnes(self):
    """Three one-bits decode to integer 7."""
    x_bit = tf.ones(shape=[1, 3], dtype=tf.float32)
    x_int = 7 * tf.ones(shape=[1], dtype=tf.int32)
    diff = discretization.bit_to_int(x_bit, num_bits=3) - x_int
    d = self.evaluate(diff)
    self.assertAllEqual(d, np.zeros_like(d))

  @test_utils.run_in_graph_and_eager_modes()
  def testIntToBitZeros(self):
    """Integer 0 encodes to all-zero bits."""
    x_bit = tf.zeros(shape=[1, 10], dtype=tf.float32)
    x_int = tf.zeros(shape=[1], dtype=tf.int32)
    diff = discretization.int_to_bit(x_int, num_bits=10) - x_bit
    d = self.evaluate(diff)
    self.assertAllEqual(d, np.zeros_like(d))

  @test_utils.run_in_graph_and_eager_modes()
  def testIntToBitOnes(self):
    """Integer 7 encodes to three one-bits."""
    x_bit = tf.ones(shape=[1, 3], dtype=tf.float32)
    x_int = 7 * tf.ones(shape=[1], dtype=tf.int32)
    diff = discretization.int_to_bit(x_int, num_bits=3) - x_bit
    d = self.evaluate(diff)
    self.assertAllEqual(d, np.zeros_like(d))

  @test_utils.run_in_graph_and_eager_modes()
  def testProjectHidden(self):
    """Projecting zeros yields zeros of shape [1, 1, num_blocks, block_dim]."""
    hidden_size = 60
    block_dim = 20
    num_blocks = 3
    x = tf.zeros(shape=[1, 1, hidden_size], dtype=tf.float32)
    projection_tensors = tf.random_normal(
        shape=[num_blocks, hidden_size, block_dim], dtype=tf.float32)
    x_projected = discretization.project_hidden(x, projection_tensors,
                                                hidden_size, num_blocks)
    x_projected_eval = self.evaluate(x_projected)
    self.assertEqual(np.shape(x_projected_eval), (1, 1, num_blocks, block_dim))
    self.assertAllEqual(x_projected_eval, np.zeros_like(x_projected_eval))

  @test_utils.run_in_graph_and_eager_modes()
  def testSliceHiddenZeros(self):
    """Slicing zeros into blocks keeps the values and splits the depth."""
    hidden_size = 60
    block_dim = 20
    num_blocks = 3
    x = tf.zeros(shape=[1, 1, hidden_size], dtype=tf.float32)
    x_sliced = discretization.slice_hidden(x, hidden_size, num_blocks)
    x_sliced_eval = self.evaluate(x_sliced)
    self.assertEqual(np.shape(x_sliced_eval), (1, 1, num_blocks, block_dim))
    self.assertAllEqual(x_sliced_eval, np.zeros_like(x_sliced_eval))

  @test_utils.run_in_graph_and_eager_modes()
  def testSliceHiddenOnes(self):
    """Slicing ones into blocks keeps the values and splits the depth."""
    hidden_size = 60
    block_dim = 20
    num_blocks = 3
    x = tf.ones(shape=[1, 1, hidden_size], dtype=tf.float32)
    x_sliced = discretization.slice_hidden(x, hidden_size, num_blocks)
    x_sliced_eval = self.evaluate(x_sliced)
    self.assertEqual(np.shape(x_sliced_eval), (1, 1, num_blocks, block_dim))
    self.assertAllEqual(x_sliced_eval, np.ones_like(x_sliced_eval))

  @test_utils.run_in_graph_and_eager_modes()
  def testNearestNeighbors(self):
    """Each vector is assigned one-hot to its closest codebook entry."""
    x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
    x = tf.reshape(x, [1, 1, 2, 3])
    means = tf.constant(
        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [9, 9, 9]], dtype=tf.float32)
    means = tf.stack([means, means], axis=0)
    x_means_hot, _ = discretization.nearest_neighbor(
        x, means, block_v_size=4)
    x_means_hot_test = np.array([[0, 1, 0, 0], [1, 0, 0, 0]])
    x_means_hot_test = np.expand_dims(x_means_hot_test, axis=0)
    x_means_hot_eval = self.evaluate(x_means_hot)
    self.assertEqual(np.shape(x_means_hot_eval), (1, 2, 4))
    self.assertAllEqual(x_means_hot_eval, x_means_hot_test)

  @test_utils.run_in_graph_mode_only()
  def testGetVQBottleneck(self):
    """A second get_vq_codebook call reuses the same (assigned) variables."""
    bottleneck_bits = 2
    bottleneck_size = 2**bottleneck_bits
    hidden_size = 3
    means, _, ema_count = discretization.get_vq_codebook(
        bottleneck_size, hidden_size)
    assign_op = means.assign(tf.zeros(shape=[bottleneck_size, hidden_size]))
    means_new, _, _ = discretization.get_vq_codebook(bottleneck_size,
                                                     hidden_size)
    with self.test_session() as sess:
      tf.global_variables_initializer().run()
      sess.run(assign_op)
      means_new_eval = sess.run(means_new)
      ema_count_eval = sess.run(ema_count)
      self.assertAllEqual(means_new_eval, np.zeros_like(means_new_eval))
      self.assertAllEqual(ema_count_eval, np.zeros_like(ema_count_eval))

  @test_utils.run_in_graph_and_eager_modes()
  def testVQNearestNeighbors(self):
    """vq_nearest_neighbor returns one-hot nearest codebook assignments."""
    x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
    means = tf.constant(
        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [9, 9, 9]], dtype=tf.float32)
    x_means_hot, _, _ = discretization.vq_nearest_neighbor(x, means)
    x_means_hot_test = np.array([[0, 1, 0, 0], [1, 0, 0, 0]])
    x_means_hot_eval = self.evaluate(x_means_hot)
    self.assertEqual(np.shape(x_means_hot_eval), (2, 4))
    self.assertAllEqual(x_means_hot_eval, x_means_hot_test)

  def testVQDiscreteBottleneck(self):
    """vq_discrete_bottleneck yields [batch, 2**bottleneck_bits] output."""
    x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
    x_means_hot, _ = discretization.vq_discrete_bottleneck(x, bottleneck_bits=2)
    self.evaluate(tf.global_variables_initializer())
    x_means_hot_eval = self.evaluate(x_means_hot)
    self.assertEqual(np.shape(x_means_hot_eval), (2, 4))

  def testVQDiscreteUnbottlenck(self):
    """vq_discrete_unbottleneck maps one-hot codes to [batch, hidden_size]."""
    x = tf.constant([[1, 0, 0, 0], [0, 0, 1, 0]], dtype=tf.int32)
    x_means = discretization.vq_discrete_unbottleneck(x, hidden_size=3)
    self.evaluate(tf.global_variables_initializer())
    x_means_eval = self.evaluate(x_means)
    self.assertEqual(np.shape(x_means_eval), (2, 3))

  def testGumbelSoftmaxDiscreteBottleneck(self):
    """gumbel_softmax_discrete_bottleneck yields [batch, 2**bits] output."""
    x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
    tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, tf.constant(1))
    x_means_hot, _ = discretization.gumbel_softmax_discrete_bottleneck(
        x, bottleneck_bits=2)
    self.evaluate(tf.global_variables_initializer())
    x_means_hot_eval = self.evaluate(x_means_hot)
    self.assertEqual(np.shape(x_means_hot_eval), (2, 4))

  @test_utils.run_in_graph_mode_only()
  def testDiscreteBottleneckVQ(self):
    """discrete_bottleneck output shapes; zero codebook stays zero."""
    hidden_size = 60
    z_size = 4
    x = tf.zeros(shape=[100, 1, hidden_size], dtype=tf.float32)
    with tf.variable_scope("test", reuse=tf.AUTO_REUSE):
      means, ema_means, ema_count = self._create_codebook(hidden_size, z_size)
      x_means_dense, x_means_hot, _, _, _ = discretization.discrete_bottleneck(
          x, hidden_size, z_size, 32, means=means, num_blocks=1,
          ema_means=ema_means, ema_count=ema_count, name="test")
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      x_means_dense_eval, x_means_hot_eval = sess.run(
          [x_means_dense, x_means_hot])
      means_eval = sess.run(means)
    self.assertEqual(x_means_dense_eval.shape, (100, 1, hidden_size))
    self.assertEqual(x_means_hot_eval.shape, (100, 1))
    self.assertAllEqual(means_eval,
                        np.zeros((1, 1, 2**z_size, hidden_size)))

  @test_utils.run_in_graph_mode_only()
  def testDiscreteBottleneckVQCond(self):
    """Same as testDiscreteBottleneckVQ but with a False `cond` tensor."""
    hidden_size = 60
    z_size = 4
    x = tf.zeros(shape=[100, 1, hidden_size], dtype=tf.float32)
    with tf.variable_scope("test2", reuse=tf.AUTO_REUSE):
      means, ema_means, ema_count = self._create_codebook(hidden_size, z_size)
      cond = tf.cast(0.0, tf.bool)
      x_means_dense, x_means_hot, _, _, _ = discretization.discrete_bottleneck(
          x, hidden_size, z_size, 32, means=means, num_blocks=1, cond=cond,
          ema_means=ema_means, ema_count=ema_count, name="test2")
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      x_means_dense_eval, x_means_hot_eval = sess.run(
          [x_means_dense, x_means_hot])
      means_eval = sess.run(means)
    self.assertEqual(x_means_dense_eval.shape, (100, 1, hidden_size))
    self.assertEqual(x_means_hot_eval.shape, (100, 1))
    self.assertAllClose(means_eval, np.zeros((1, 1, 2**z_size,
                                              hidden_size)))
if __name__ == "__main__":
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/latent_layers.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for latent variable models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import transformer_layers
from tensor2tensor.utils import beam_search
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow_probability as tfp
# Module-level switch. Presumably gates whether helpers in this module emit
# summary ops (no use is visible in this excerpt) — TODO confirm.
DO_SUMMARIES = True
def compress_self_attention_layer(x, hparams, name=None):
  """Multi-head self-attention block with layer pre/post-processing.

  The input goes through `cia.maybe_reshape_4d_to_3d` for attention. The
  result is reshaped back to the shape reported by that helper (presumably
  the shape of `x`).

  Args:
    x: Input `Tensor` (3-D, or 4-D which is flattened for attention).
    hparams: HParams with attention and pre/post-processing settings.
    name: Optional variable scope; defaults to "compress_self_attention".

  Returns:
    The post-processed attention output, reshaped as described above.
  """
  with tf.variable_scope(name, default_name="compress_self_attention"):
    flat_x, original_shape, _ = cia.maybe_reshape_4d_to_3d(x)
    key_depth = hparams.attention_key_channels or hparams.hidden_size
    value_depth = hparams.attention_value_channels or hparams.hidden_size
    normed = common_layers.layer_preprocess(flat_x, hparams)
    # No memory antecedent and no bias: the block attends over its own input.
    attended = common_attention.multihead_attention(
        normed, None, None, key_depth, value_depth,
        hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
    out = common_layers.layer_postprocess(flat_x, attended, hparams)
    return tf.reshape(out, original_shape)
def compute_nats_and_bits_per_dim(data_dim,
                                  latent_dim,
                                  average_reconstruction,
                                  average_prior):
  """Computes negative ELBO, which is an upper bound on the negative likelihood.

  negative_elbo = data_dim * average_reconstruction + latent_dim * average_prior,
  reported per data dimension in nats and in bits.

  Args:
    data_dim: int-like indicating data dimensionality.
    latent_dim: int-like indicating latent dimensionality.
    average_reconstruction: Scalar Tensor indicating the reconstruction cost
      averaged over all data dimensions and any data batches.
    average_prior: Scalar Tensor indicating the negative log-prior probability
      averaged over all latent dimensions and any data batches.

  Returns:
    Tuple of scalar Tensors, representing the nats and bits per data dimension
    (e.g., subpixels) respectively.
  """
  with tf.name_scope(None, default_name="compute_nats_per_dim"):
    num_data_dims = tf.cast(data_dim, average_reconstruction.dtype)
    num_latent_dims = tf.cast(latent_dim, average_prior.dtype)
    total_reconstruction = num_data_dims * average_reconstruction
    total_prior = num_latent_dims * average_prior
    negative_elbo = total_reconstruction + total_prior
    nats_per_dim = tf.divide(negative_elbo, num_data_dims, name="nats_per_dim")
    # Convert nats to bits by dividing by ln(2).
    bits_per_dim = tf.divide(nats_per_dim, tf.log(2.), name="bits_per_dim")
    return nats_per_dim, bits_per_dim
def multinomial_sample(x, vocab_size=None, sampling_method="random",
                       temperature=1.0):
  """Multinomial sampling from a n-dimensional tensor.

  Draws one class per position from `softmax(x / temperature)` when
  `sampling_method == "random"` and `temperature > 0`; otherwise takes the
  argmax.

  Args:
    x: Tensor of shape [..., vocab_size]. Parameterizes logits of multinomial.
    vocab_size: Number of classes in multinomial distribution. Defaults to the
      last dimension of `x`.
    sampling_method: String, "random" or otherwise deterministic.
    temperature: Positive float.

  Returns:
    Tensor of shape [...].
  """
  num_classes = vocab_size or common_layers.shape_list(x)[-1]
  use_sampling = sampling_method == "random" and temperature > 0.0
  if not use_sampling:
    samples = tf.argmax(x, axis=-1)
  else:
    flat_logits = tf.reshape(x, [-1, num_classes])
    samples = tf.multinomial(flat_logits / temperature, 1)
  return tf.reshape(samples, common_layers.shape_list(x)[:-1])
def ae_latent_softmax(latents_pred, latents_discrete_hot, vocab_size, hparams):
  """Latent prediction and loss.

  Projects `latents_pred` to `vocab_size` logits (optionally RMS-normalized),
  scores them against the one-hot targets, and samples from them.

  Args:
    latents_pred: Tensor of shape [..., depth].
    latents_discrete_hot: Tensor of shape [..., vocab_size].
    vocab_size: an int representing the vocab size.
    hparams: HParams.

  Returns:
    sample: Tensor of shape [...], a sample from a multinomial distribution.
    loss: Tensor of shape [...], the softmax cross-entropy.
  """
  with tf.variable_scope("latent_logits"):
    logits = tf.layers.dense(latents_pred, vocab_size, name="logits_dense")
    if hparams.logit_normalization:
      # Rescale to unit root-mean-square; 1e-8 guards against rsqrt(0).
      logits *= tf.rsqrt(1e-8 + tf.reduce_mean(tf.square(logits)))
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=latents_discrete_hot, logits=logits)

    # TODO(trandustin): tease this out from ae_latent_softmax.
    # we use just the loss portion to anchor prior / encoder on text.
    sample = multinomial_sample(
        logits, vocab_size, hparams.sampling_method, hparams.sampling_temp)
    return sample, loss
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Samples from the latent space in the autoencoder.

  Args:
    latents_dense_in: Tensor of shape [batch, length_q, ...]. Only the shape of
      its first two dimensions is used. length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Encodings
      to attend to in decoder.
    ed: Tensor which broadcasts with shape [batch, hparams.num_heads, length_q,
      length_kv]. Encoder-decoder attention bias.
    embed: Callable which embeds discrete latent hot-vectors and a hidden size
      and returns dense vectors.
    hparams: HParams.

  Returns:
    Tensor of shape [batch, length].
  """
  vocab_size = 2**hparams.bottleneck_bits

  def symbols_to_logits_fn(ids):
    """Returns logits for the latest position given the decoded prefix."""
    prefix = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    # Drop the leading all-zeros id and right-pad by one, keeping the length.
    shifted = tf.pad(prefix[:, 1:], [[0, 0], [0, 1], [0, 0]])

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      dense_latents = embed(tf.one_hot(shifted, depth=vocab_size),
                            hparams.hidden_size)
      decoded = transformer_latent_decoder(
          dense_latents, inputs, ed, hparams, name="latent_prediction")
      all_logits = tf.layers.dense(decoded, vocab_size, name="logits_dense")
      position = common_layers.shape_list(prefix)[1] - 1
      step_logits = all_logits[:, position, :]
    return step_logits

  batch_size = tf.shape(latents_dense_in)[0]
  initial_ids = tf.zeros([batch_size], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  beams, _, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      1,
      length,
      vocab_size,
      alpha=0.0,
      eos_id=-1,
      stop_early=False)

  first_beam = tf.expand_dims(beams[:, 0, :], axis=2)  # Pick first beam.
  return first_beam[:, 1:]  # Remove the added all-zeros from ids.
def residual_block_layer(inputs, hparams):
  """Residual block over inputs.

  Runs `hparams.num_res_layers` residual blocks, each consisting of
    conv: kernel_size x kernel_size (on the layer-normalized input)
    conv: 1x1
  followed by dropout, add and normalize according to
  hparams.layer_postprocess_sequence.

  Args:
    inputs: Tensor of shape [batch, height, width, hparams.hidden_size].
    hparams: HParams.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  kernel = (hparams.res_kernel_size, hparams.res_kernel_size)
  x = inputs
  for layer_idx in range(hparams.num_res_layers):
    with tf.variable_scope("res_conv_%d" % layer_idx):
      normed = common_layers.layer_norm(x, hparams.hidden_size, name="lnorm")
      # kernel_size x kernel_size conv block.
      wide = common_layers.conv_block(
          normed,
          hparams.hidden_size, [((1, 1), kernel)],
          strides=(1, 1),
          padding="SAME",
          name="residual_conv")
      # 1x1 conv block.
      pointwise = common_layers.conv_block(
          wide,
          hparams.hidden_size, [((1, 1), (1, 1))],
          strides=(1, 1),
          padding="SAME",
          name="residual_dense")
      x = common_layers.layer_postprocess(x, pointwise, hparams)
  return x
def compress_encoder(inputs,
                     hparams,
                     strides=(2, 2),
                     kernel_size=(3, 3),
                     name=None):
  """Encoder that compresses 2-D inputs by 2**num_compress_steps.

  Applies `hparams.num_compress_steps // 2` strided conv blocks. When
  `hparams.do_compress_attend` is set, each is followed by a residual
  self-attention layer over the compressed activations. Then come a residual
  block and a dense layer that widens the depth by `hparams.num_latents`;
  those copies are folded into the length dimension.

  Args:
    inputs: Tensor of shape [batch, height, width, channels].
    hparams: HParams.
    strides: Tuple, strides for conv block.
    kernel_size: Tuple, kernel window size for conv block.
    name: string, variable scope. Also the prefix of the final dense layer's
      name; when None, "compress" is used as that prefix.

  Returns:
    Tensor of shape [batch, latent_length, hparams.hidden_size], where
    latent_length is hparams.num_latents * height' * width', and height' and
    width' are the spatial sizes after `hparams.num_compress_steps // 2`
    strided convolutions. With the default strides (2, 2) and an even
    `num_compress_steps` this is
    hparams.num_latents * (height*width) / 2**(hparams.num_compress_steps).
  """
  # `name` defaults to None; `None + "_dense"` would raise a TypeError, so fall
  # back to the scope's default name for the dense-layer prefix.
  prefix = name or "compress"
  with tf.variable_scope(name, default_name="compress"):
    x = inputs
    for i in range(hparams.num_compress_steps // 2):
      with tf.variable_scope("compress_conv_%d" % i):
        y = common_layers.conv_block(
            common_layers.layer_norm(
                x, hparams.hidden_size, name="lnorm"),
            hparams.hidden_size,
            dilation_rates_and_kernel_sizes=[((1, 1), kernel_size)],
            strides=strides,
            padding="SAME",
            name="compress_conv_%d" % i)
        y = tf.nn.dropout(y, 1.0 - hparams.dropout)
        if hparams.do_compress_attend:
          # Residual self-attention over the *compressed* activations `y`.
          # Attending over the pre-stride `x` would discard the conv output and
          # leave this step uncompressed.
          y += compress_self_attention_layer(
              y, hparams, name="compress_selfatt_%d" % i)
      x = y

    x = residual_block_layer(x, hparams)

    # If using multiple copies of latents, blow up the hidden size and then
    # reshape to increase by num_latents.
    shape_x = common_layers.shape_list(x)
    x = tf.layers.dense(x,
                        hparams.num_latents * hparams.hidden_size,
                        name=prefix + "_dense")
    return tf.reshape(x, [shape_x[0],
                          shape_x[1] * shape_x[2] * hparams.num_latents,
                          hparams.hidden_size])
def compress_encoder_2d(x, hparams, name=None):
  """Encoder that compresses 2-D inputs by 2**num_compress_steps.

  Thin wrapper around `compress_encoder` using strides (2, 2) and a square
  `hparams.kernel_size` kernel.

  Args:
    x: Tensor of shape [batch, height, width, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, latent_length, hparams.hidden_size], where
    latent_length is
    hparams.num_latents * (height*width) / 2**(hparams.num_compress_steps).
  """
  square_kernel = (hparams.kernel_size, hparams.kernel_size)
  return compress_encoder(x, hparams, strides=(2, 2),
                          kernel_size=square_kernel, name=name)
def compress_encoder_1d(x, hparams, name=None):
  """Encoder that compresses 1-D inputs along the length dimension.

  A singleton width axis is added and the result is passed to
  `compress_encoder` with strides (2, 1). Each of the
  `hparams.num_compress_steps // 2` strided convolutions therefore halves
  only the length (padding is "SAME").

  Args:
    x: Tensor of shape [batch, length, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, latent_length, hparams.hidden_size], where
    latent_length is
    hparams.num_latents * length / 2**(hparams.num_compress_steps // 2)
    for lengths divisible by that power of two.
  """
  x = tf.expand_dims(x, axis=2)  # [batch, length, 1, channels]
  return compress_encoder(x,
                          hparams,
                          strides=(2, 1),
                          kernel_size=(hparams.kernel_size, 1),
                          name=name)
def decompress_decoder(inputs,
                       hparams,
                       strides=(2, 2),
                       kernel=(3, 3),
                       name=None):
  """Decoder that decompresses 2-D inputs by 2**num_compress_steps.

  A dense layer and a residual block are followed by
  `hparams.num_compress_steps // 2` transposed convolutions with `strides`.
  When `hparams.do_decompress_attend` is set, each transposed convolution is
  preceded by a residual self-attention layer.

  Args:
    inputs: Tensor of shape [batch, compress_height, compress_width, channels].
    hparams: HParams.
    strides: Tuple, strides for conv block.
    kernel: Tuple, kernel window size for conv block.
    name: string, variable scope. Also the prefix of the dense layer's and the
      per-step sub-scopes' names; when None, "decompress" is used as that
      prefix.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  # `name` defaults to None; `None + "_dense"` / `None + "_%d"` would raise a
  # TypeError, so fall back to the scope's default name for these prefixes.
  prefix = name or "decompress"
  with tf.variable_scope(name, default_name="decompress"):
    x = tf.layers.dense(inputs, hparams.hidden_size, name=prefix + "_dense")
    x = residual_block_layer(x, hparams)
    num_steps = hparams.num_compress_steps // 2
    for i in range(num_steps):
      # Sub-scopes are numbered in reverse, mirroring compress_encoder's steps.
      j = num_steps - i - 1
      with tf.variable_scope(prefix + "_%d" % j):
        if hparams.do_decompress_attend:
          y = compress_self_attention_layer(
              x, hparams, name="decompress_selfatt")
          x += y
        # NOTE(review): relu is applied to every step except the first, which
        # includes the final output layer when num_steps > 1 — confirm that
        # this (rather than skipping the activation on the last step) is
        # intended.
        y = tf.layers.conv2d_transpose(
            x,
            hparams.hidden_size,
            kernel,
            strides=strides,
            padding="SAME",
            activation=tf.nn.relu if i > 0 else None,
            name="decompress_conv")
        x = y
    return x
def decompress_decoder_2d(x, hparams, name=None):
  """Decoder that decompresses 2-D inputs by 2**num_compress_steps.

  Thin wrapper around `decompress_decoder` using strides (2, 2) and a square
  `hparams.kernel_size` kernel.

  Args:
    x: Tensor of shape [batch, compress_height, compress_width, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  square_kernel = (hparams.kernel_size, hparams.kernel_size)
  return decompress_decoder(x, hparams, strides=(2, 2),
                            kernel=square_kernel, name=name)
def decompress_decoder_1d(x, hparams, name=None):
  """Decoder that decompresses 1-D inputs along the length dimension.

  A singleton width axis is added and the result is passed to
  `decompress_decoder` with strides (2, 1). Each of the
  `hparams.num_compress_steps // 2` transposed convolutions therefore
  doubles only the length. The singleton axis is removed on return.

  Args:
    x: Tensor of shape [batch, compress_length, channels].
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length, hparams.hidden_size], where
    length = compress_length * 2**(hparams.num_compress_steps // 2).
  """
  x = tf.expand_dims(x, axis=2)  # [batch, compress_length, 1, channels]
  output = decompress_decoder(x, hparams,
                              strides=(2, 1),
                              kernel=(hparams.kernel_size, 1),
                              name=name)
  return tf.squeeze(output, axis=2)
def transformer_text_encoder(inputs,
                             target_space,
                             hparams,
                             name=None):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size].
    target_space: int. Used for encoding inputs under a target space id.
    hparams: HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name, default_name="transformer_text_encoder"):
    flat_inputs = common_layers.flatten4d3d(inputs)
    encoder_input, self_attention_bias, ed = (
        transformer_layers.transformer_prepare_encoder(
            flat_inputs, target_space=target_space, hparams=hparams))
    # tf.compat.v1 dropout takes keep_prob as its second argument.
    keep_prob = 1.0 - hparams.dropout
    encoder_input = tf.nn.dropout(encoder_input, keep_prob)
    encoder_output = transformer_layers.transformer_encoder(
        encoder_input, self_attention_bias, hparams)
  return encoder_output, ed
def transformer_image_decoder(targets,
                              encoder_output,
                              ed_attention_bias,
                              hparams,
                              name=None):
  """Transformer image decoder over targets with local attention.

  The targets are reshaped into an img_len x img_len grid whose channels hold
  all colour channels' embeddings. The grid is decoded with
  `cia.transformer_decoder_layers` using hparams.dec_attention_type, then
  reshaped so that each colour channel gets its own column.

  Args:
    targets: Tensor of shape [batch, ...], and whose size is batch * height *
      width * hparams.num_channels * hparams.hidden_size.
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, height, width * hparams.num_channels,
    hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_dec"):
    batch_size = common_layers.shape_list(targets)[0]
    img_len = hparams.img_len
    grid_shape = [batch_size, img_len, img_len,
                  hparams.num_channels * hparams.hidden_size]
    grid = tf.reshape(targets, grid_shape)
    decoder_input, _, _ = cia.prepare_decoder(grid, hparams)
    # Fall back to the shared layer count when no decoder-specific one is set.
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    decoded = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    output_shape = [batch_size, img_len, img_len * hparams.num_channels,
                    hparams.hidden_size]
    return tf.reshape(decoded, output_shape)
def transformer_latent_decoder(x,
                               encoder_output,
                               ed_attention_bias,
                               hparams,
                               name=None):
  """Transformer decoder over latents using latent_attention_type.

  The flat latent sequence is laid out as a square grid with side
  img_len // 2**(num_compress_steps // 2) and num_latents latents per cell
  along the width. It is decoded with `cia.transformer_decoder_layers` and
  then flattened back into a sequence.

  Args:
    x: Tensor of shape [batch, length_q, hparams.hidden_size]. length_q is the
      latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_latent_dec"):
    batch_size = common_layers.shape_list(x)[0]
    side = hparams.img_len // 2**(hparams.num_compress_steps // 2)
    grid = tf.reshape(
        x, [batch_size, side, side * hparams.num_latents, hparams.hidden_size])
    decoder_input, _, _ = cia.prepare_decoder(grid, hparams)
    # Fall back to the shared layer count when no latent-specific one is set.
    num_layers = hparams.num_latent_layers or hparams.num_hidden_layers
    decoded = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    sequence_len = side**2 * hparams.num_latents
    return tf.reshape(decoded,
                      [batch_size, sequence_len, hparams.hidden_size])
def bottleneck_layer(inputs,
                     hparams,
                     name="discrete_bottleneck"):
  """Computes latents given inputs (typically, compressed targets).

  Args:
    inputs: Tensor forwarded to `hparams.bottleneck`.
    hparams: HParams. `hparams.bottleneck` must be a callable returning a
      5-element sequence whose first four entries are the dense latents, the
      discrete latents, the extra loss and an embedding function. The fifth
      entry is ignored.
    name: string, forwarded to the bottleneck as its name.

  Returns:
    Tuple (latents_dense, latents_discrete, extra_loss, embed_fn).
  """
  outputs = hparams.bottleneck(inputs=inputs,
                               filter_size=hparams.compress_filter_size,
                               name=name,
                               mode=hparams.mode)
  latents_dense, latents_discrete, extra_loss, embed_fn, _ = outputs
  if DO_SUMMARIES:
    flat_discrete = tf.reshape(latents_discrete, [-1])
    tf.summary.histogram("discrete_latents", flat_discrete)
  return latents_dense, latents_discrete, extra_loss, embed_fn
def latent_prediction_model(inputs,
                            ed_attention_bias,
                            latents_discrete,
                            latents_dense,
                            hparams,
                            vocab_size=None,
                            name=None):
  """Transformer-based latent prediction model.

  It is an autoregressive decoder over latents_discrete given inputs. In
  PREDICT mode nothing is built and both return values are None.

  Args:
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Inputs to
      attend to for the decoder on latents.
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    latents_discrete: Latents to compute the log-probability of, given
      inputs. When hparams.soft_em is True they must already be one-hot,
      of shape [batch, length_q, vocab_size]. Otherwise they are integer
      codes and are one-hot encoded here with depth vocab_size.
    latents_dense: Tensor of shape [batch, length_q, hparams.hidden_size].
      length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    hparams: HParams.
    vocab_size: int or None. If None, it is 2**hparams.bottleneck_bits.
    name: string, variable scope.

  Returns:
    latents_pred: Tensor of shape [batch, length_q, hparams.hidden_size], or
      None when hparams.mode is PREDICT.
    latents_pred_loss: Tensor of shape [batch, length_q], or None when
      hparams.mode is PREDICT.
  """
  # PREDICT mode skips the decoder entirely. Without these defaults the
  # return statement below would raise UnboundLocalError in that mode.
  latents_pred = None
  latent_pred_loss = None
  with tf.variable_scope(name, default_name="latent_prediction"):
    if hparams.mode != tf_estimator.ModeKeys.PREDICT:
      # Gradients from the prior must not flow back into the latents.
      latents_pred = transformer_latent_decoder(tf.stop_gradient(latents_dense),
                                                inputs,
                                                ed_attention_bias,
                                                hparams,
                                                name)
      if vocab_size is None:
        vocab_size = 2**hparams.bottleneck_bits
      if not hparams.soft_em:
        # TODO(trandustin): latents_discrete is not one-hot from
        # discrete_bottleneck unless hparams.soft_em is True. Refactor.
        latents_discrete = tf.one_hot(latents_discrete, depth=vocab_size)
      _, latent_pred_loss = ae_latent_softmax(
          latents_pred, tf.stop_gradient(latents_discrete), vocab_size, hparams)
  return latents_pred, latent_pred_loss
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """Auto-encoder using a Transformer decoder and a prior over latent sequences.

  Targets are compressed and passed through `hparams.bottleneck`. Outside
  PREDICT mode, the discrete latents also train a latent prediction model.
  In PREDICT mode, latent codes are sampled with `ae_latent_sample_beam`
  (unless `cache` already holds them) and embedded. Either way, the dense
  latents are decompressed back to the target shape. Optionally they are
  mixed with gold targets and decoded autoregressively.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
    targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
      dimensions denoting sequence length. Needed in every mode, including
      PREDICT, because the bottleneck is built from the compressed targets.
    target_space: int. Used for encoding inputs under a target space id.
    hparams: HParams.
    cache: Tensor of shape [batch, length] or None. Only read in PREDICT mode.
    predict_mask: Tensor masking whether to use gold targets or predictions.
      Only read in PREDICT mode when hparams.use_gold_targets is set.

  Returns:
    decoder_output: Tensor of shape [batch, ..., hparams.hidden_size]
      presenting pre-logit activations. After a transformation (`top` in
      `T2TModel`), it is used with targets to compute the "training"
      (reconstruction) loss.
    losses: dict of str to Tensors. There are three loss terms: "extra",
      "extra_loss", and "latent_pred". The first is hard-coded to 0. Outside
      PREDICT mode the latter two are Tensors of shape [batch]. In PREDICT
      mode they are 0.
    cache: Tensor of shape [batch, length], either the same as cache, or newly
      computed if the cache input is None (PREDICT mode only).
  """
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  # Rank-4 targets are handled as 2-D grids; anything else as 1-D sequences.
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  ed_attention_bias = None
  if inputs is not None:
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, name="input_encoder")

  losses = {"extra": 0.,
            "extra_loss": 0.,
            "latent_pred": 0.}
  # Compressed targets are needed in every mode. Training/eval feed them to
  # the bottleneck, and PREDICT also calls the bottleneck on them to obtain
  # embed_fn. (Computing them only outside PREDICT left the name unbound in
  # the PREDICT branch.)
  targets_compressed = compress_fn(targets, hparams, name="compress")
  if hparams.mode != tf_estimator.ModeKeys.PREDICT:
    if hparams.mode == tf_estimator.ModeKeys.TRAIN:
      scale = common_layers.inverse_exp_decay(hparams.startup_steps)
    else:
      scale = 1.0
    # Per-example 0/1 gate that is 1.0 when a uniform draw falls below scale.
    scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

    latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
        targets_compressed, hparams)
    extra_loss = scale * tf.reduce_mean(extra_loss)

    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    # The latent prediction loss is switched off until the global step
    # exceeds hparams.mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
    latents_pred_loss *= tf.to_float(latent_time)

    # Apply dropout noise for each data point and time step.
    latents_dense_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense,
        keep_prob=1 - hparams.latent_dropout,
        noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

    # TODO(trandustin): Can we combine extra and extra_loss?
    losses = {"extra": 0.,
              "extra_loss": extra_loss,
              "latent_pred": latents_pred_loss}
  else:
    # Set the latent length, which is num_latents times the number of latent
    # pixels. The number of latent pixels is determined by a compression factor
    # on the number of image pixels. Integer division keeps it a valid
    # tf.zeros dimension (true division is in effect in this module).
    latent_len = ((hparams.img_len * hparams.img_len * hparams.num_latents) //
                  (2**hparams.num_compress_steps))
    _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense,
                                    inputs,
                                    ed_attention_bias,
                                    embed_fn,
                                    hparams)
    cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

  if len(original_targets_shape) == 4:
    compressed_img_len = (hparams.img_len //
                          2**(hparams.num_compress_steps // 2))
    latents_dense = tf.reshape(latents_dense,
                               [batch_size,
                                compressed_img_len,
                                compressed_img_len,
                                hparams.num_latents * hparams.hidden_size])

  latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
  latents_dense = tf.reshape(
      latents_dense,
      [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    if hparams.mode == tf_estimator.ModeKeys.PREDICT:
      masking = predict_mask
    else:
      masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
    # Positions where the uniform draw exceeds `masking` take gold targets.
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    latents_dense = mask * targets + (1.0 - mask) * latents_dense

  latents_dense = tf.reshape(latents_dense, original_targets_shape)
  if hparams.decode_autoregressive:
    decoder_output = transformer_image_decoder(
        latents_dense, inputs, ed_attention_bias, hparams, name="decoder")
  else:
    decoder_output = latents_dense
  return decoder_output, losses, cache
def iaf_flow(one_hot_assignments,
             scale_weights,
             scale_bias,
             num_codes,
             summary=True,
             name=None):
  """Performs a single IAF flow using scale and normalization transformations.

  Args:
    one_hot_assignments: Assignments Tensor with shape [num_samples, batch_size,
      latent_size, num_codes].
    scale_weights: Tensor corresponding to lower triangular matrix used to
      autoregressively generate scale matrix from assignments. To ensure the
      lower-triangular matrix has length of latent_size, scale_weights should
      be a rank-one tensor with size latent_size * (latent_size + 1) / 2.
    scale_bias: Bias tensor to be added to scale tensor, with shape
      [latent_size, num_codes]. If scale weights are zero, initialize scale_bias
      to be log(exp(1.) / 2. - 1) so initial transformation is identity.
    num_codes: Number of codes in codebook.
    summary: Whether to save summaries.
    name: String used for name scope.

  Returns:
    flow_output: Transformed one-hot assignments.
    inverse_log_det_jacobian: Inverse log determinant of Jacobian corresponding
      to transformation.
  """
  with tf.name_scope(name, default_name="iaf"):
    # Zero-pad the first latent position and drop the last one. This shifts
    # the assignments down by one along the latent axis.
    shifted = tf.pad(one_hot_assignments,
                     [[0, 0], [0, 0], [1, 0], [0, 0]])[:, :, :-1, :]
    lower_tri = tfp.math.fill_triangular(scale_weights)
    affine = tfp.distributions.bijectors.Affine(scale_tril=lower_tri)
    # The bijector multiplies along the last axis, so move the latent axis
    # last for the call and swap it back afterwards.
    scale = affine.forward(tf.transpose(shifted, [0, 1, 3, 2]))
    scale = tf.transpose(scale, [0, 1, 3, 2])
    scale = (tf.nn.softplus(scale) +
             tf.nn.softplus(scale_bias[tf.newaxis, tf.newaxis, ...]))
    # The last code is left unscaled, so its scale entry is not needed.
    scale = scale[..., :-1]

    z = one_hot_assignments[..., :-1]
    scaled_codes = z * scale
    last_code = one_hot_assignments[..., -1, tf.newaxis]
    unnormalized_probs = tf.concat([scaled_codes, last_code], axis=-1)
    normalizer = tf.reduce_sum(unnormalized_probs, axis=-1)
    flow_output = unnormalized_probs / normalizer[..., tf.newaxis]
    log_scale_sum = tf.reduce_sum(tf.log(scale), axis=-1)
    inverse_log_det_jacobian = (-log_scale_sum +
                                num_codes * tf.log(normalizer))
    if summary:
      tf.summary.histogram("iaf/scale", tf.reshape(scale, [-1]))
      tf.summary.histogram("iaf/inverse_log_det_jacobian",
                           tf.reshape(inverse_log_det_jacobian, [-1]))
    return flow_output, inverse_log_det_jacobian
================================================
FILE: tensor2tensor/layers/latent_layers_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for layers in latent variable models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import discretization
from tensor2tensor.layers import latent_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# Turn on eager execution for this whole test module (TF1-compat API); it must
# run before any graph-building calls below.
tf.enable_eager_execution()
def imagetransformer_latent_tiny():
  """Tiny set of hparams for a latent image model.

  Starts from `transformer.transformer_small()`, overrides some of its
  existing fields, then adds the extra fields the latent layers read. Both
  steps are applied in the listed order.

  Returns:
    HParams for fast unit tests.
  """
  hparams = transformer.transformer_small()
  # Existing transformer_small fields to override.
  overrides = [
      ("batch_size", 2),
      ("num_hidden_layers", 3),
      ("hidden_size", 16),
      ("filter_size", 32),
      ("compress_filter_size", 64),
      ("ffn_layer", "conv_hidden_relu"),
      ("layer_prepostprocess_dropout", 0.2),
      ("layer_preprocess_sequence", "none"),
      ("layer_postprocess_sequence", "dan"),
      ("dropout", 0.3),
      ("pos", "timing"),
      ("num_encoder_layers", 1),
      ("num_decoder_layers", 2),
      ("use_pad_remover", False),
  ]
  for field, value in overrides:
    setattr(hparams, field, value)
  # New fields used by the latent-variable layers.
  additions = [
      ("logit_normalization", True),
      ("bottleneck_kind", "dvq"),
      ("bottleneck_bits", 4),
      ("num_residuals", 1),
      ("use_gold_targets", False),
      ("do_compress_attend", False),
      ("do_decompress_attend", False),
      ("drop_inputs", False),
      ("num_compress_steps", 2),
      ("startup_steps", 10000),
      ("mask_startup_steps", 50000),
      ("latent_dropout", 0.0),
      ("decode_autoregressive", False),
      ("vq_beta", 0.25),
      ("vq_epsilon", 1e-5),
      ("vq_decay", 0.999),
      ("ema", False),
      ("soft_em", True),
      ("num_samples", 1),
      ("num_latent_layers", 2),
      ("num_res_layers", 2),
      ("res_kernel_size", 3),
      ("num_blocks", 1),
      ("reshape_method", "slice"),
      ("shared_rel", False),
      ("block_size", 1),
      ("kernel_size", 3),
      ("img_len", 8),
      ("num_channels", 1),
      ("local_and_global_att", False),
      ("block_length", 32),
      ("block_width", 128),
      ("dec_attention_type", cia.AttentionType.LOCAL_1D),
      ("latent_attention_type", cia.AttentionType.GLOBAL),
      ("block_raster_scan", False),
      ("num_latents", 1),
      ("q_filter_width", 1),
      ("kv_filter_width", 1),
  ]
  for field, value in additions:
    hparams.add_hparam(field, value)
  return hparams
class LatentLayersTest(tf.test.TestCase):
  """Unit tests for `latent_layers`."""

  @test_utils.run_in_graph_and_eager_modes()
  def testComputeBitsAndNats(self):
    """Bits per dim multiplied by ln(2) must equal nats per dim."""
    reconstruction_loss = tf.random_uniform(())
    prior_loss = tf.random_uniform(())
    data_dim = tf.random_uniform((), maxval=1000, dtype=tf.int32)
    latent_dim = tf.random_uniform((), maxval=1000, dtype=tf.int32)
    nats, bits = latent_layers.compute_nats_and_bits_per_dim(
        data_dim, latent_dim, reconstruction_loss, prior_loss)
    nats_value, bits_as_nats_value = self.evaluate([nats, bits * tf.log(2.)])
    self.assertAllClose(nats_value, bits_as_nats_value)

  @test_utils.run_in_graph_and_eager_modes()
  def testTransformerAutoencoder(self):
    """Builds a TRAIN-mode autoencoder and checks output and loss shapes."""
    hparams = imagetransformer_latent_tiny()
    hparams.mode = tf_estimator.ModeKeys.TRAIN
    num_codebooks = hparams.num_residuals * hparams.num_blocks
    block_dim = int(hparams.hidden_size // hparams.num_blocks)
    block_v_size = int(2**(hparams.bottleneck_bits / num_codebooks))
    means = tf.get_variable(
        name="means",
        shape=[hparams.num_residuals, hparams.num_blocks, block_v_size,
               block_dim],
        initializer=tf.uniform_unit_scaling_initializer())
    hparams.bottleneck = functools.partial(
        discretization.discrete_bottleneck,
        hidden_size=hparams.hidden_size,
        z_size=hparams.bottleneck_bits,
        filter_size=hparams.filter_size,
        startup_steps=hparams.startup_steps,
        bottleneck_kind=hparams.bottleneck_kind,
        num_blocks=hparams.num_blocks,
        num_residuals=hparams.num_residuals,
        reshape_method=hparams.reshape_method,
        beta=hparams.vq_beta,
        decay=hparams.vq_decay,
        soft_em=hparams.soft_em,
        num_samples=hparams.num_samples,
        epsilon=hparams.vq_epsilon,
        ema=hparams.ema,
        means=means)

    batch_size = hparams.batch_size
    image_shape = (batch_size, hparams.img_len, hparams.img_len,
                   hparams.hidden_size)
    targets = tf.random_uniform(list(image_shape), minval=-1., maxval=1.)
    tf.train.create_global_step()
    # No conditioning inputs and no target space id.
    decoder_output, losses, cache = latent_layers.transformer_autoencoder(
        None, targets, None, hparams)

    self.assertEqual(set(losses), {"extra", "extra_loss", "latent_pred"})
    self.evaluate(tf.global_variables_initializer())
    output_value, extra_loss_value, latent_pred_value = self.evaluate(
        [decoder_output, losses["extra_loss"], losses["latent_pred"]])
    self.assertEqual(output_value.shape, image_shape)
    self.assertEqual(extra_loss_value.shape, (batch_size,))
    self.assertEqual(latent_pred_value.shape, (batch_size,))
    self.assertAllGreaterEqual(extra_loss_value, 0.)
    self.assertAllGreaterEqual(latent_pred_value, 0.)
    self.assertIsNone(cache)
if __name__ == "__main__":
  # Run this module's tests with TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/message_passing_attention.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import expert_utils
import tensorflow.compat.v1 as tf
def multihead_graph_attention(query_antecedent,
                              memory_antecedent,
                              bias,
                              total_key_depth,
                              total_value_depth,
                              output_depth,
                              num_heads,
                              dropout_rate,
                              image_shapes=None,
                              attention_type="edge_vector",
                              name="multihead_graph_attention",
                              save_weights_to=None,
                              make_image_summary=True,
                              dropout_broadcast_dims=None,
                              adjacency_matrix=None,
                              num_edge_types=5,
                              vars_3d=False,
                              **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_type: either the string "edge_vector" (see graph_attention()) or
      any attention function with the signature (query, key, value, **kwargs).
      A callable may return a tuple (output, extra), in which case extra is
      returned alongside the result.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor is stored there under
      a string key created from the variable scope (including name).
      Only used when attention_type is "edge_vector".
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: an optional tensor of shape
      [batch, length_q, length_q, num_edge_types] whose entries weight the
      learned edge-type vectors (see make_edge_vectors()). Only used when
      attention_type is "edge_vector".
    num_edge_types: number of edge types, an int
    vars_3d: use 3-dimensional variables for input/output transformations
    **kwargs (dict): Parameters for a callable attention_type.

  Returns:
    The result of the attention transformation. The output shape is
    [batch_size, length_q, output_depth]. If a callable attention_type
    returned a tuple (output, extra), the pair (result, extra) is returned.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads, or if attention_type is neither
      "edge_vector" nor callable.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  # Reject unsupported attention types before any variables are created.
  # Previously they fell through the dispatch below and left `x` unbound.
  if not callable(attention_type) and attention_type != "edge_vector":
    raise ValueError("Unsupported attention_type %r: expected \"edge_vector\" "
                     "or a callable." % (attention_type,))
  vars_3d_num_heads = num_heads if vars_3d else None
  with tf.variable_scope(
      name,
      default_name="multihead_attention",
      values=[query_antecedent, memory_antecedent]):

    q, k, v = common_attention.compute_qkv(
        query_antecedent,
        memory_antecedent,
        total_key_depth,
        total_value_depth,
        vars_3d_num_heads=vars_3d_num_heads)
    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention.
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack.
    else:  # attention_type == "edge_vector", validated above.
      x = graph_attention(
          q,
          k,
          v,
          bias,
          dropout_rate,
          image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims,
          adjacency_matrix=adjacency_matrix,
          num_edge_types=num_edge_types)

    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
@expert_utils.add_name_scope()
def make_edge_vectors(adjacency_matrix,
                      num_edge_types,
                      depth,
                      name=None):
  """Gets edge vectors for the edge types in the adjacency matrix.

  Each edge type owns a learned vector of size `depth`. The vector for a
  node pair is the adjacency-weighted sum of the edge-type vectors for that
  pair.

  Args:
    adjacency_matrix: A [batch, num_nodes, num_nodes, num_edge_types] tensor.
    num_edge_types: Number of different edge types
    depth: Number of channels
    name: A optional string name for scoping

  Returns:
    A [batch, num_nodes, num_nodes, depth] vector of tensors
  """
  with tf.variable_scope(name, default_name="edge_vectors"):
    adj_shape = common_layers.shape_list(adjacency_matrix)
    # Initialized with stddev depth**-0.5 and then multiplied by depth**0.5,
    # so the effective edge-type vectors start with unit standard deviation.
    edge_type_vectors = tf.get_variable(
        "adj_vectors",
        [num_edge_types, depth],
        initializer=tf.random_normal_initializer(0, depth**-0.5))
    edge_type_vectors = edge_type_vectors * (depth**0.5)
    flat_adjacency = tf.reshape(tf.to_float(adjacency_matrix),
                                [-1, num_edge_types])
    # One row per node pair: [batch * num_nodes * num_nodes, depth].
    pair_vectors = tf.matmul(flat_adjacency, edge_type_vectors)
    target_shape = [adj_shape[0], adj_shape[1], adj_shape[2], depth]
    pair_vectors = tf.reshape(pair_vectors, target_shape)
  return pair_vectors
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
  """Dot-product attention with optional learned edge-type logits.

  Logits are q k^T, plus (when adjacency_matrix is given) the dot product of
  each query with the edge vector from make_edge_vectors(), plus bias. They
  are softmaxed, dropped out, and used to weight v.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the (pre-dropout) weights tensor is stored there
      under the name of the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: optional Tensor of shape
      [batch, length_q, length_q, num_edge_types] whose entries weight the
      learned edge-type vectors (see make_edge_vectors()). When given,
      length_q must equal length_kv so the edge logits match q k^T.
    num_edge_types: an int indicating number of edge types

  Returns:
    A Tensor of shape [batch, heads, length_q, depth_v].
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(q, k, transpose_b=True)
    if adjacency_matrix is not None:
      key_head_depth = common_layers.shape_list(q)[-1]
      # [batch, length_q, length_q, depth_k]
      adjacency_vectors = make_edge_vectors(
          adjacency_matrix,
          num_edge_types,
          key_head_depth,
          name=name)
      # Transpose q to [batch, length_q, heads, depth_k] so it can be
      # batch-multiplied with the edge vectors [batch, length_q, length_q,
      # depth_k], giving [batch, length_q, heads, length_q].
      q_t = tf.transpose(q, [0, 2, 1, 3])
      adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
      # Back to [batch, heads, length_q, length_q] to line up with logits.
      logits += tf.transpose(adj_logits, [0, 2, 1, 3])
    if bias is not None:
      logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
    # Dropping out the attention links for each of the heads.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
def _compute_edge_transforms(node_states,
                             depth,
                             num_transforms,
                             name="transform"):
  """Helper function that computes transformation for keys and values.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let T be the total number of transforms (num_transforms).

  Computes the transforms for keys or values for attention.
  * For each node N_j and edge type t, a key K_jt of size K is computed. When an
    edge of type t goes from node N_j to any other node, K_jt is the key that is
    used in the attention process.
  * For each node N_j and edge type t, a value V_jt of size V is computed. When
    an edge of type t goes from node N_j to node N_i, Attention(Q_i, K_jt)
    produces a weight w_ijt. The message sent along this edge is w_ijt * V_jt.

  Args:
    node_states: A tensor of shape [B, N, D]
    depth: An integer (K or V)
    num_transforms: An integer (T),
    name: A name for the function

  Returns:
    x: The attention keys or values for each node and edge type
      (shape [B, N*T, K or V]). Row n*T + t holds transform t of node n.
  """
  node_shapes = common_layers.shape_list(node_states)
  x = common_layers.dense(
      node_states,
      depth * num_transforms,
      use_bias=False,
      name=name)

  batch = node_shapes[0]  # B.
  length = node_shapes[1]  # N.

  # The dense output for each node is laid out as T consecutive chunks of
  # size `depth` (i.e. [T, depth], not [depth, T]). Merging the node and
  # transform axes gives one row per (node, transform) pair.
  return tf.reshape(x, [batch, length * num_transforms, depth])
def compute_mpnn_qkv(node_states,
                     total_key_depth,
                     total_value_depth,
                     num_transforms):
  """Computes query, key and value for edge matrices.

  Let B be the number of batches.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let T be the total number of transforms (num_transforms).

  Computes the queries, keys, and values for attention.
  * For each node N_i in the graph, a query Q_i of size K is computed. This
    query is used to determine the relative weights to give to each of the
    node's incoming edges.
  * For each node N_j and edge type t, a key K_jt of size K is computed. When an
    edge of type t goes from node N_j to any other node, K_jt is the key that is
    used in the attention process.
  * For each node N_j and edge type t, a value V_jt of size V is computed. When
    an edge of type t goes from node N_j to node N_i, Attention(Q_i, K_jt)
    produces a weight w_ijt. The message sent along this edge is w_ijt * V_jt.

  Args:
    node_states: A Tensor with shape [B, N, D].
    total_key_depth: an integer (K).
    total_value_depth: an integer (V).
    num_transforms: a integer specifying number of transforms (T). This is
      typically the number of edge types.

  Returns:
    q: The attention queries for each destination node (shape [B, N, K]).
    k: The attention keys for each node and edge type (shape [B, N*T, K]).
    v: The attention values for each node and edge type (shape [B, N*T, V]).
  """
  # node_states is a [B, N, D] tensor. The dense layer applies one shared
  # D x K kernel to every node's size-D hidden state, producing one size-K
  # query per node: q has shape [B, N, K]. There is no per-edge-type query.
  q = common_layers.dense(
      node_states, total_key_depth, use_bias=False, name="q_mpnn")

  # Keys and values get one transform per edge type. _compute_edge_transforms
  # flattens the node and transform axes together, so k has shape
  # [B, N*T, K] and v has shape [B, N*T, V].
  k = _compute_edge_transforms(node_states,
                               total_key_depth,
                               num_transforms,
                               name="k_mpnn")
  v = _compute_edge_transforms(node_states,
                               total_value_depth,
                               num_transforms,
                               name="v_mpnn")

  return q, k, v
def sparse_message_pass_batched(node_states,
                                adjacency_matrices,
                                num_edge_types,
                                hidden_size,
                                use_bias=True,
                                average_aggregation=False,
                                name="sparse_ggnn_batched"):
  """Batched variant of `sparse_message_pass`.

  The B graphs in the batch are merged into one block-diagonal graph of B*N
  nodes by offsetting each edge's node indices by batch_index * N. A single
  `sparse_message_pass` step runs on the merged graph, and the result is
  reshaped back into per-graph node states.

  B = The batch size.
  N = The number of nodes in each batch.
  H = The size of the hidden states.
  T = The number of edge types.

  Args:
    node_states: Initial states of each node in the graph. Shape: [B, N, H]
    adjacency_matrices: Adjacency matrices of directed edges for each edge
      type and batch. Shape: [B, N, N, T] (sparse).
    num_edge_types: The number of edge types. T.
    hidden_size: The size of the hidden layer. H.
    use_bias: Whether to use bias in the hidden layer.
    average_aggregation: How to aggregate the incoming node messages. If
      average_aggregation is true, the messages are averaged. If it is false,
      they are summed.
    name: (optional) The scope within which tf variables should be created.

  Returns:
    The result of one round of message-passing of shape [B, N, H].
  """
  dynamic_shape = tf.shape(node_states)
  batch, num_nodes = dynamic_shape[0], dynamic_shape[1]
  total_nodes = batch * num_nodes

  # Fold the batch axis into the node axis.
  flat_states = tf.reshape(node_states, [total_nodes, hidden_size])

  # Sparse indices are (batch, row, col, edge_type). Shift row and col by the
  # batch offset so every graph occupies its own diagonal block.
  idx = adjacency_matrices.indices
  offset = idx[:, 0] * tf.cast(num_nodes, tf.int64)
  merged_indices = tf.stack(
      [idx[:, 1] + offset, idx[:, 2] + offset, idx[:, 3]], axis=1)
  total_nodes_64 = tf.cast(total_nodes, tf.int64)
  merged_adjacency = tf.SparseTensor(
      indices=merged_indices,
      values=adjacency_matrices.values,
      dense_shape=[total_nodes_64, total_nodes_64, num_edge_types])

  flat_result = sparse_message_pass(
      flat_states,
      merged_adjacency,
      num_edge_types,
      hidden_size,
      use_bias=use_bias,
      average_aggregation=average_aggregation,
      name=name)
  # Restore the batch dimension.
  return tf.reshape(flat_result, [batch, num_nodes, hidden_size])
def sparse_message_pass(node_states,
                        adjacency_matrices,
                        num_edge_types,
                        hidden_size,
                        use_bias=True,
                        average_aggregation=False,
                        name="sparse_ggnn"):
  """One message-passing step for a GNN with a sparse adjacency matrix.

  Implements equation 2 (the message passing step) in
  [Li et al. 2015](https://arxiv.org/abs/1511.05493).

  N = The number of nodes in the graph.
  H = The size of the hidden states.
  T = The number of edge types.

  Args:
    node_states: Initial states of each node in the graph. Shape is [N, H].
    adjacency_matrices: Adjacency matrix of directed edges for each edge
      type. Shape is [N, N, T] (sparse tensor). Entry [i, j, t] weights the
      state of node j in the message received by node i over type-t edges.
    num_edge_types: The number of edge types. T.
    hidden_size: The size of the hidden state. H.
    use_bias: Whether to add a per-edge-type bias, scaled by the number of
      incoming edges of that type.
    average_aggregation: How to aggregate the incoming node messages. If
      average_aggregation is true, the messages are averaged over the number
      of incoming edges. If it is false, they are summed.
    name: (optional) The scope within which tf variables should be created.

  Returns:
    The result of one step of Gated Graph Neural Network (GGNN) message passing.
    Shape: [N, H]
  """
  n = tf.shape(node_states)[0]
  t = num_edge_types

  # [N, T]: number (or total weight) of incoming edges per node and edge type.
  incoming_edges_per_type = tf.sparse_reduce_sum(adjacency_matrices, axis=1)

  # Convert the adjacency matrix into shape [T, N, N] - one [N, N] adjacency
  # matrix for each edge type. Since sparse tensor multiplication only supports
  # two-dimensional tensors, we actually convert the adjacency matrix into a
  # [T * N, N] tensor.
  adjacency_matrices = tf.sparse_transpose(adjacency_matrices, [2, 0, 1])
  adjacency_matrices = tf.sparse_reshape(adjacency_matrices, [t * n, n])

  # Multiply the adjacency matrix by the node states, producing a [T * N, H]
  # tensor. For each (edge type, node) pair, this tensor stores the sum of
  # the hidden states of the node's neighbors over incoming edges of that type.
  messages = tf.sparse_tensor_dense_matmul(adjacency_matrices, node_states)

  # Rearrange this tensor to have shape [N, T * H]. The incoming states of each
  # node's neighbors are summed by edge type and then concatenated together
  # into a single T * H vector.
  messages = tf.reshape(messages, [t, n, hidden_size])
  messages = tf.transpose(messages, [1, 0, 2])
  messages = tf.reshape(messages, [n, t * hidden_size])

  # Run each of those T * H vectors through a linear layer that produces
  # a vector of size H. This is equivalent to running each H-sized vector
  # through a separate linear layer for each edge type and adding the results.
  #
  # Neighbor states sharing an edge type were summed earlier. Because a
  # (bias-free) linear layer distributes over addition, this is equivalent to
  # running each incoming edge through the layer separately and summing.
  with tf.variable_scope(name, default_name="sparse_ggnn"):
    final_node_states = common_layers.dense(
        messages, hidden_size, use_bias=False)

    # Add one bias vector per edge type, multiplied by the number of incoming
    # edges of that type.
    if use_bias:
      bias = tf.get_variable("bias", initializer=tf.zeros([t, hidden_size]))
      final_node_states += tf.matmul(incoming_edges_per_type, bias)

    if average_aggregation:
      # [N, 1] broadcasts against [N, H]. The epsilon keeps nodes with no
      # incoming edges (whose states are all zero) from dividing by zero.
      incoming_edges = tf.reduce_sum(incoming_edges_per_type, -1, keepdims=True)
      final_node_states /= incoming_edges + 1e-7

  return tf.reshape(final_node_states, [n, hidden_size])
def multihead_mpnn_attention(node_states,
                             total_key_depth,
                             total_value_depth,
                             output_depth,
                             num_heads,
                             adjacency_matrix=None,
                             num_edge_types=5,
                             num_transforms=None,
                             use_weighted_sum=False,
                             name="mpnn_attention"):
  """Multihead scaled-dot-product attention with input/output transformations.

  Let B be the batch size.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let K be the size of the attention keys/queries (total_key_depth).
  Let V be the size of the attention values (total_value_depth).
  Let O be the size of the attention output (output_depth).
  Let H be the number of heads (num_heads).
  Let T be the total number of transforms (num_transforms).

  The key and value depths are split across all of the heads. For example, if
  the key depth is 6 and there are three heads, then the key for each head has
  depth 2.

  Args:
    node_states: A Tensor with shape [B, N, D]
    total_key_depth: An integer (K).
    total_value_depth: An integer (V).
    output_depth: An integer (O).
    num_heads: An integer (H).
    adjacency_matrix: A Tensor of ints with shape [B, N, N]. If there is an
      edge from node j to node i in batch b, then adjacency_matrix[b, i, j]
      contains the type of that edge as an integer. Otherwise, it contains 0.
      It is passed unchanged to dot_product_mpnn_attention for every head.
    num_edge_types: An integer indicating number of edge types.
    num_transforms: An integer indicating number of transforms (T). If None,
      then num_transforms will be equal to num_edge_types.
    use_weighted_sum: If False, will only use a single transform per edge type.
      Otherwise, use a learned weighted sum of transforms per edge type.
    name: A string.

  Returns:
    The result of the attention transformation. The output shape is [B, N, O].

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(
      name, default_name="multihead_mpnn_attention", values=[node_states]):
    # If not explicitly set, use num_transforms set to num_edge_types.
    num_transforms = (
        num_edge_types if num_transforms is None else num_transforms)
    key_depth_per_head = total_key_depth // num_heads
    value_depth_per_head = total_value_depth // num_heads

    # Create the query for each node's incoming edges, and the keys/values for
    # each node for each possible outgoing edge type.
    q, k, v = compute_mpnn_qkv(
        node_states,
        total_key_depth,
        total_value_depth,
        num_transforms)
    q_shape = tf.shape(q)  # As above, q_shape is [B, N, K].

    # Divide each query/key/value into H separate attention heads along its
    # last dimension. The results have shape [B, H, N, ?/H], where
    # ? = K, K*T or V*T as appropriate.
    q = common_attention.split_heads(q, num_heads)  # Shape [B, H, N, K/H].
    k = common_attention.split_heads(k, num_heads)  # Shape [B, H, N, K*T/H].
    v = common_attention.split_heads(v, num_heads)  # Shape [B, H, N, V*T/H].

    # Ensures that the logits don't have too large of a magnitude.
    q *= key_depth_per_head**-0.5

    # Put the head dimension first so we can index one head at a time.
    q = tf.transpose(q, [1, 0, 2, 3])  # Shape [H, B, N, K/H].
    k = tf.transpose(k, [1, 0, 2, 3])  # Shape [H, B, N, K*T/H].
    v = tf.transpose(v, [1, 0, 2, 3])  # Shape [H, B, N, V*T/H].

    # Split the keys and values into separate per-edge-type keys and values.
    k = tf.reshape(k, [
        num_heads, q_shape[0], q_shape[1], num_transforms, key_depth_per_head
    ])  # Shape [H, B, N, T, K/H].
    k = tf.transpose(k, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, K/H].
    v = tf.reshape(v, [
        num_heads, q_shape[0], q_shape[1], num_transforms, value_depth_per_head
    ])  # Shape [H, B, N, T, V/H].
    v = tf.transpose(v, [0, 1, 3, 2, 4])  # Shape [H, B, T, N, V/H].

    # Perform attention for each head. Each output has shape [B, N, V/H]: for
    # node n, the incoming values weighted by attention over its incoming
    # edges and summed.
    head_outputs = [
        dot_product_mpnn_attention(
            q[head_id],
            k[head_id],
            v[head_id],
            adjacency_matrix,
            num_edge_types,
            num_transforms=num_transforms,
            use_weighted_sum=use_weighted_sum)
        for head_id in range(num_heads)
    ]

    # Combine the heads together into one tensor and rearrange the dimensions.
    x = tf.stack(head_outputs, axis=0)  # Shape [H, B, N, V/H].
    x = tf.transpose(x, [1, 0, 2, 3])  # Shape [B, H, N, V/H].
    # Concatenate the values produced by each head together into one vector.
    x = common_attention.combine_heads(x)  # Shape [B, N, V].

    # A fully-connected linear layer to convert from the value vectors of size V
    # to output vectors of length O (the appropriate output length).
    x = common_layers.dense(
        x, output_depth, use_bias=False, name="output_transform")
    return x
def dot_product_mpnn_attention(q,
                               k,
                               v,
                               adjacency_matrix,
                               num_edge_types,
                               num_transforms=None,
                               use_weighted_sum=False,
                               name=None):
  """Dot product attention with edge vectors.

  Let B be the batch size.
  Let N be the number of nodes in the graph.
  Let K be the size of the attention keys/queries.
  Let V be the size of the attention values.
  Let T be the total number of transforms (num_transforms).

  Args:
    q: The query Tensor of shape [B, N, K].
    k: The key Tensor of shape [B, T, N, K].
    v: The value Tensor of shape [B, T, N, V].
    adjacency_matrix: A Tensor of ints of shape [B, N, N]. If there is an edge
      from node j to node i in batch b, adjacency_matrix[b, i, j] holds the
      type of that edge; otherwise it holds 0. For a rank-4 input the
      missing-edge mask is computed from the sum over the last axis.
    num_edge_types: An integer specifying number of edge types.
    num_transforms: An integer indicating number of transforms (T). If None,
      then num_transforms will be equal to num_edge_types.
    use_weighted_sum: If False, will only use a single transform per edge type.
      Otherwise, use a learned weighted sum of transforms per edge type.
    name: A string.

  Returns:
    A Tensor of shape [B, N, V] storing the result of computing attention
    weights using the queries and keys and combining the values according to
    those weights.

  Raises:
    ValueError: if num_transforms doesn't equal num_edge_types and not using
      weighted sum.
  """
  with tf.variable_scope(
      name,
      default_name="dot_product_mpnn_attention",
      values=[q, k, v, adjacency_matrix, num_edge_types]):
    # If not explicitly set, use num_transforms set to num_edge_types.
    num_transforms = (
        num_edge_types if num_transforms is None else num_transforms)
    if not use_weighted_sum and num_transforms != num_edge_types:
      raise ValueError("num_transforms must equal num_edge_types unless "
                       "use_weighted_sum is True")
    # Accept numpy arrays too; the rank check below needs a Tensor.
    adjacency_matrix = tf.convert_to_tensor(adjacency_matrix)

    # Raw dot products between each node's query and every node's key, one
    # N x N matrix per transform: [B, T, N, N]. Entry (i, j) scores the edge
    # from node j to node i, i.e. how much node i attends to node j. The tile
    # must match k's transform dimension (num_transforms), which can differ
    # from num_edge_types when use_weighted_sum is True.
    all_edge_logits = tf.matmul(
        tf.tile(tf.expand_dims(q, axis=1), [1, num_transforms, 1, 1]),
        k,
        transpose_b=True)

    # Construct edge_vectors of shape [B, N, N, T].
    if use_weighted_sum:
      # Use dense representation for edge vectors.
      edge_vectors = make_edge_vectors(
          adjacency_matrix,
          num_edge_types,
          num_transforms)
    else:
      # Generate one-hot vectors based on edge types. If there is an edge from
      # node j to node i of type t, index t of the last dimension is 1 for
      # entry (i, j).
      # NOTE(review): type id 0 doubles as "no edge", so one_hot still selects
      # transform 0 for non-edges; the -1e9 mask below keeps them out of the
      # softmax.
      edge_vectors = tf.one_hot(adjacency_matrix, num_transforms)
    # Rearranging the dimensions to match the shape of all_edge_logits.
    edge_vectors = tf.transpose(edge_vectors, [0, 3, 1, 2])  # [B, T, N, N].

    # Zero out query-key products that do not correspond to an edge of the
    # matching type. all_edge_logits retains shape [B, T, N, N].
    all_edge_logits *= edge_vectors

    # Since there can only be one edge from node A to node B, we can collapse
    # the T different adjacency matrices containing key-query pairs into one
    # adjacency matrix. logits is [B, N, N].
    # TODO(dbieber): Use a reshape instead of reduce sum to attend over all
    # edges instead of over all neighboring nodes to handle the multigraph case.
    logits = tf.reduce_sum(all_edge_logits, axis=1)

    # Add a large negative bias wherever there is no edge, so that those
    # entries get (numerically) zero weight after the softmax. The mask must
    # have the same [B, N, N] shape as logits: for integer edge types 0 means
    # "no edge"; for a [B, N, N, T] indicator tensor an all-zero last axis does.
    if adjacency_matrix.shape.ndims == 3:
      missing_edges = tf.equal(adjacency_matrix, 0)
    else:
      missing_edges = tf.equal(tf.reduce_sum(adjacency_matrix, axis=-1), 0)
    logits += tf.to_float(missing_edges) * -1e9

    # Turn the raw key-query products into attention weights over the last
    # dimension (the source nodes).
    compatibility = tf.nn.softmax(logits)  # Shape [B, N, N].

    # Computes a summary showing the attention matrix as an image. Does not do
    # any work toward actually performing attention.
    common_attention.attention_image_summary(
        tf.expand_dims(compatibility, axis=1), None)

    # Repeat the attention matrix once per transform: [B, T, N, N].
    edge_compatibility = tf.tile(
        tf.expand_dims(compatibility, axis=1), [1, num_transforms, 1, 1])
    # Zeroes out the entries in edge_compatibility that do not correspond to
    # actual edges.
    edge_compatibility *= edge_vectors  # Shape [B, T, N, N].
    output = compute_values(edge_compatibility, v)
    return output
def ggnn_fast_dense(node_states,
                    adjacency_matrix,
                    num_edge_types,
                    total_value_depth,
                    name=None):
  """ggnn version of the MPNN from Gilmer et al.

  Unlike dot_product_mpnn_attention, no attention is computed: each incoming
  edge's per-type value is weighted directly by the adjacency entry.

  Let B be the batch size.
  Let N be the number of nodes in the graph.
  Let D be the size of the node hidden states.
  Let V be the size of the output of the ggnn.
  Let T be the number of transforms / edge types.

  Args:
    node_states: The node state Tensor of shape [B, N, D].
    adjacency_matrix: A Tensor of shape [B, N, N, T]. An entry at
      indices b, i, j, k is the indicator of the edge from node j to node i in
      batch b. A standard adjacency matrix will only have values of one, while a
      mutigraph may have larger integer values.
    num_edge_types: An integer specifying number of edge types (T).
    total_value_depth: An integer (V).
    name: A string.

  Returns:
    A Tensor of shape [B, N, V]: for each node, the sum over edge types and
    source nodes of the transformed source states, weighted by the
    corresponding adjacency entries.
  """
  with tf.variable_scope(
      name,
      default_name="ggnn_fast_dense",
      values=[node_states, adjacency_matrix, num_edge_types]):
    nodes_shape = common_layers.shape_list(node_states)
    # One V-sized transform of every node state per edge type.
    v = _compute_edge_transforms(node_states,
                                 total_value_depth,
                                 num_edge_types,
                                 name="v_mpnn")
    v = tf.reshape(v, [nodes_shape[0], nodes_shape[1], num_edge_types,
                       total_value_depth
                      ])  # Shape [B, N, T, V].
    v = tf.transpose(v, [0, 2, 1, 3])  # Shape [B, T, N, V].

    # The adjacency matrix plays the role of the attention weights; move the
    # edge-type axis forward to match v: [B, T, N, N].
    edge_vectors = tf.transpose(adjacency_matrix, [0, 3, 1, 2])
    output = compute_values(edge_vectors, v)
    return output
def compute_values(edge_compatibility, v):
  """Compute values. If edge compatibilities is just adjacency, we get ggnn.

  Args:
    edge_compatibility: A tensor of shape [batch, num_transforms, length,
      length]. Entry [b, t, i, j] weights the value of node j in the output of
      node i for transform t. It is cast to float32 before use.
    v: A tensor of shape [batch, num_transforms, length, depth]

  Returns:
    output: A [batch, length, depth] tensor
  """
  # Computes the incoming value vectors for each node by weighting them
  # according to the attention weights. These values are still segregated by
  # edge type.
  # Shape = [B, T, N, V].
  all_edge_values = tf.matmul(tf.to_float(edge_compatibility), v)

  # Combines the weighted value vectors together across edge types into a
  # single N x V matrix for each batch.
  output = tf.reduce_sum(all_edge_values, axis=1)  # Shape [B, N, V].
  return output
def precompute_edge_matrices(adjacency, hparams):
  """Precompute the a_in and a_out tensors.

  (we don't want to add to the graph everytime _fprop is called)

  Each edge vector adjacency[b, i, j] is passed through a small feed-forward
  "edge network" that outputs a D x D matrix (D = hparams.hidden_size). These
  matrices are then laid out as the (i, j) blocks of one large matrix per
  batch element.

  Args:
    adjacency: Tensor of real valued edge vectors of shape [B, L, L, E]
    hparams: HParams object; reads edge_network_layers,
      edge_network_hidden_size and hidden_size (and whatever
      common_layers.layer_preprocess reads).

  Returns:
    edge_matrices: [batch, L * D, L * D] the dense matrix for message passing
    viewed as a block matrix of (L, L) blocks of size (D, D). Each block is a
    function of the edge vector of the adjacency matrix at that spot.
  """
  batch_size, num_nodes, _, edge_dim = common_layers.shape_list(adjacency)

  # build the edge_network for incoming edges
  with tf.variable_scope("edge_network"):
    x = tf.reshape(
        adjacency, [batch_size * num_nodes * num_nodes, edge_dim],
        name="adj_reshape_in")

    for ip_layer in range(hparams.edge_network_layers):
      name = "edge_network_layer_%d" % ip_layer
      x = tf.layers.dense(common_layers.layer_preprocess(x, hparams),
                          hparams.edge_network_hidden_size,
                          activation=tf.nn.relu,
                          name=name)
    x = tf.layers.dense(common_layers.layer_preprocess(x, hparams),
                        hparams.hidden_size**2,
                        activation=None,
                        name="edge_network_output")

  # x = [batch * l * l, d * d]
  edge_matrices_flat = tf.reshape(x, [batch_size, num_nodes,
                                      num_nodes, hparams.hidden_size,
                                      hparams.hidden_size])

  # Interleave node and feature axes ([B, L, D, L, D]) so that the reshape to
  # [batch, l * d, l * d] places the matrix for edge (i, j) in block (i, j).
  edge_matrices = tf.reshape(
      tf.transpose(edge_matrices_flat, [0, 1, 3, 2, 4]), [
          -1, num_nodes * hparams.hidden_size,
          num_nodes * hparams.hidden_size
      ],
      name="edge_matrices")

  return edge_matrices
def dense_message_pass(node_states, edge_matrices):
  """Computes a_t from h_{t-1}, see bottom of page 3 in the paper.

  The node states of each batch element are stacked into one [L * D, 1]
  column and multiplied by the [L * D, L * D] edge matrix, whose (i, j) block
  of size D x D maps node j's state into node i's message. A learned bias of
  size D is then added to every node's message.

  Args:
    node_states: [B, L, D] tensor (h_{t-1})
    edge_matrices (tf.float32): [B, L*D, L*D]

  Returns:
    messages (tf.float32): [B, L, D] aggregated message for every node.
  """
  batch_size, num_nodes, node_dim = common_layers.shape_list(node_states)

  # One long column vector of node states per batch element.
  stacked_states = tf.reshape(
      node_states, [batch_size, num_nodes * node_dim, 1], name="h_flat")
  product = tf.matmul(edge_matrices, stacked_states)
  per_node = tf.reshape(
      product, [batch_size * num_nodes, node_dim], name="messages_matmul")

  bias = tf.get_variable("message_bias", shape=node_dim)
  with_bias = per_node + bias
  return tf.reshape(with_bias, [batch_size, num_nodes, node_dim])
================================================
FILE: tensor2tensor/layers/modalities.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modalities, which specify a feature's domain.
T2TModel applies a default transformation to each feature according to its
modality. Override them by specifying a model's
hparams.{bottom,loss,top,weights_fn}.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_audio
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import common_video
from tensor2tensor.layers import discretization
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow_probability as tfp
class ModalityType(object):
  """Types of modalities.

  Each constant is the string identifier of a modality; get_choices() lists
  all of them.
  """

  AUDIO = "audio"
  AUDIO_SPECTRAL = "audio_spectral"
  CLASS_LABEL = "class_label"
  CTC_SYMBOL = "ctc_symbol"  # symbol with CTC loss
  GENERIC_L2_LOSS = "generic_l2"  # identity modality with L2 loss
  IDENTITY = "identity"  # identity top and bottom
  IDENTITY_SYMBOL = "identity_symbol"  # symbol with identity top and bottom
  IMAGE = "image"
  # images using channel compression for generation
  IMAGE_CHANNEL_BOTTOM_IDENTITY = "image_channel_bottom_identity"
  # images using channel compression for generation
  IMAGE_CHANNEL_COMPRESS = "image_channel_compress"
  IMAGE_CHANNEL_EMBEDDINGS_BOTTOM = "image_channel_embeddings_bottom"
  MULTI_LABEL = "multi_label"
  ONE_HOT_CLASS_LABEL = "one_hot_class_label"
  REAL = "real"  # real vectors
  REAL_L2_LOSS = "real_l2"  # real vectors with L2 as loss
  # real vectors with log Poisson regression loss
  REAL_LOG_POISSON_LOSS = "real_log_poisson"
  SIGMOID_CLASS_LABEL = "sigmoid_class_label"  # sigmoid cross-entropy loss
  # sigmoid cross-entropy applied on max-pooling over timesteps
  SIGMOID_MAX_POOLING_CLASS_LABEL = "sigmoid_max_pooling_class_label"
  # softmax cross-entropy applied on average-pooling over timesteps
  SOFTMAX_AVERAGE_POOLING_CLASS_LABEL = "softmax_average_pooling_class_label"
  # softmax cross-entropy applied on last-timestep encoding
  SOFTMAX_LAST_TIMESTEP_CLASS_LABEL = "softmax_last_timestep_class_label"
  # softmax cross-entropy applied on max-pooling over timesteps
  SOFTMAX_MAX_POOLING_CLASS_LABEL = "softmax_max_pooling_class_label"
  SPEECH_RECOGNITION = "speech_recognition"
  SYMBOL = "symbol"
  SYMBOL_WEIGHTS_ALL = "symbol_weights_all"  # symbol for features w/o 0-padding
  SYMBOL_ONE_HOT = "symbol_one_hot"  # symbol with one hot as embeddings
  VIDEO = "video"
  VIDEO_BITWISE = "video_bitwise"  # video where bottom embeds pixels bitwise
  VIDEO_IDENTITY = "video_identity"  # video with identity top and bottom
  VIDEO_L1 = "video_l1"  # video with L1 loss
  VIDEO_L2 = "video_l2"  # video with L2 loss
  # video with L1 loss and raw input (sequences of frames)
  VIDEO_L1_RAW = "video_l1_raw"
  # video with L2 loss and raw input (sequences of frames)
  VIDEO_L2_RAW = "video_l2_raw"
  # video with pixel noise on input during training
  VIDEO_PIXEL_NOISE = "video_pixel_noise"

  @staticmethod
  def get_choices():
    """Returns the list of all modality type identifiers."""
    return [
        ModalityType.AUDIO,
        ModalityType.AUDIO_SPECTRAL,
        ModalityType.CLASS_LABEL,
        ModalityType.CTC_SYMBOL,
        ModalityType.GENERIC_L2_LOSS,
        ModalityType.IDENTITY,
        ModalityType.IDENTITY_SYMBOL,
        ModalityType.IMAGE,
        ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
        ModalityType.IMAGE_CHANNEL_COMPRESS,
        ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM,
        ModalityType.MULTI_LABEL,
        ModalityType.ONE_HOT_CLASS_LABEL,
        ModalityType.REAL,
        ModalityType.REAL_L2_LOSS,
        ModalityType.REAL_LOG_POISSON_LOSS,
        ModalityType.SIGMOID_CLASS_LABEL,
        ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
        ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL,
        ModalityType.SPEECH_RECOGNITION,
        ModalityType.SYMBOL,
        ModalityType.SYMBOL_ONE_HOT,
        ModalityType.SYMBOL_WEIGHTS_ALL,
        ModalityType.VIDEO,
        ModalityType.VIDEO_BITWISE,
        ModalityType.VIDEO_IDENTITY,
        ModalityType.VIDEO_L1,
        ModalityType.VIDEO_L2,
        ModalityType.VIDEO_L1_RAW,
        ModalityType.VIDEO_L2_RAW,
        ModalityType.VIDEO_PIXEL_NOISE,
    ]
# Bottom transformations, applied to all features
def audio_bottom(x, model_hparams, vocab_size):
  """Transform raw audio input from data space to model space.

  The samples are cast to float and divided by 255, then fed through
  `model_hparams.audio_compression` Xception residual blocks (each with stride
  2 in both spatial dimensions and 2**(i + 1) filters), and finally one more
  block producing `model_hparams.hidden_size` channels.

  Args:
    x: A Tensor with shape [batch, ...]
    model_hparams: HParams, model hyperparmeters.
    vocab_size: int, vocabulary size (unused).

  Returns:
    body_input: A Tensor with shape [batch, ?, ?,
      model_hparams.hidden_size].
  """
  del vocab_size  # unused arg

  # TODO(aidangomez): Will need to sort out a better audio pipeline
  def residual_block(block_input, num_filters, relu_on_shortcut, scope_name):
    """Xception block: separable convs + max-pool, plus a strided shortcut."""
    with tf.variable_scope(scope_name):
      # Typically audio samples are >100k samples in length and have a width
      # of 2 or 4. Mono audio has a single channel while stereo has 2.
      main_branch = common_layers.separable_conv_block(
          block_input,
          num_filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
          first_relu=True,
          padding="SAME",
          force2d=True,
          name="sep_conv_block")
      main_branch = common_layers.pool(
          main_branch, (3, 3), "MAX", "SAME", strides=(2, 2))
      shortcut = common_layers.conv_block(
          block_input,
          num_filters, [((1, 1), (1, 1))],
          padding="SAME",
          strides=(2, 2),
          first_relu=relu_on_shortcut,
          force2d=True,
          name="res_conv0")
      return main_branch + shortcut

  with tf.variable_scope("audio_modality"):
    features = tf.to_float(x) / 255.
    features.set_shape([None, None, None, 1])
    for level in range(model_hparams.audio_compression):
      features = residual_block(
          features, 2**(level + 1), True, "compress_block_%d" % level)
    return residual_block(
        features, model_hparams.hidden_size, False, "compress_block_final")
def audio_spectral_bottom(x, model_hparams, vocab_size):
  """Transform spectral audio input from data space to model space.

  The int32 input is bitcast back to float32, then fed through
  `model_hparams.audio_compression` Xception-like residual blocks, followed by
  one more block producing `model_hparams.hidden_size` channels. The blocks
  stride only along the length dimension (stride (2, 1)).

  Args:
    x: A Tensor with shape [batch, ...]
    model_hparams: HParams, model hyperparmeters.
    vocab_size: int, vocabulary size (unused).

  Returns:
    body_input: A Tensor with shape [batch, ?, ?,
      model_hparams.hidden_size].
  """
  del vocab_size  # unused arg

  # TODO(aidangomez): Will need to sort out a better audio pipeline
  def residual_block(block_input, num_filters, relu_on_shortcut, scope_name):
    """Xception-like block that downsamples only along the length axis."""
    with tf.variable_scope(scope_name):
      # Only stride along length to preserve the spectral bins, which are
      # tiny in dimensionality relative to length.
      main_branch = common_layers.separable_conv_block(
          block_input,
          num_filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
          first_relu=True,
          padding="SAME",
          force2d=True,
          name="sep_conv_block")
      main_branch = common_layers.pool(
          main_branch, (3, 3), "MAX", "SAME", strides=(2, 1))
      shortcut = common_layers.conv_block(
          block_input,
          num_filters, [((1, 1), (1, 1))],
          padding="SAME",
          strides=(2, 1),
          first_relu=relu_on_shortcut,
          force2d=True,
          name="res_conv0")
      return main_branch + shortcut

  with tf.variable_scope("audio_spectral_modality"):
    # Bitcast back from int32
    features = tf.bitcast(x, tf.float32)
    features.set_shape([None, None, None, 1])
    for level in range(model_hparams.audio_compression):
      features = residual_block(
          features, 2**(level + 1), True, "compress_block_%d" % level)
    return residual_block(
        features, model_hparams.hidden_size, False, "compress_block_final")
def class_label_bottom(x, model_hparams, vocab_size):
  """Embeds class-label ids into vectors of size model_hparams.hidden_size.

  The embedding is scaled by sqrt(hidden_size) when
  model_hparams.multiply_embedding_mode is "sqrt_depth", and by 1.0 otherwise.
  """
  hidden_size = model_hparams.hidden_size
  scope_name = "class_label_modality_%d_%d" % (vocab_size, hidden_size)
  with tf.variable_scope(scope_name):
    scale_by_sqrt = model_hparams.multiply_embedding_mode == "sqrt_depth"
    multiplier = hidden_size**0.5 if scale_by_sqrt else 1.0
    return common_layers.embedding(
        x, vocab_size, hidden_size, multiplier=multiplier)
def class_label_targets_bottom(x, model_hparams, vocab_size):
  """Returns all-zero targets of shape [batch, 1, 1, hidden_size].

  The target labels themselves are not embedded; only the batch size of x is
  used.
  """
  hidden_size = model_hparams.hidden_size
  scope_name = "class_label_modality_%d_%d" % (vocab_size, hidden_size)
  with tf.variable_scope(scope_name):
    batch_size = common_layers.shape_list(x)[0]
    return tf.zeros([batch_size, 1, 1, hidden_size])
def identity_bottom(x, model_hparams, vocab_size):
  """Returns the input unchanged except for a cast to float32."""
  del model_hparams  # unused arg
  del vocab_size  # unused arg
  as_float = tf.to_float(x)
  return as_float
def image_bottom(x, model_hparams, vocab_size):
  """Casts input images to float32, logging an image summary in graph mode."""
  del model_hparams, vocab_size  # unused arg
  with tf.variable_scope("image_modality"):
    if tf.executing_eagerly():
      return tf.to_float(x)
    safe_images = common_layers.tpu_safe_image_summary(x)
    tf.summary.image("inputs", safe_images, max_outputs=2)
    return tf.to_float(x)
def image_targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for target images.

  Every channel value is one-hot encoded over vocab_size and embedded with a
  learned [vocab_size, 64] table. The per-channel embeddings of a pixel are
  concatenated and projected to model_hparams.hidden_size with a dense layer.

  Args:
    x: Tensor of shape [batch, height, width, channels] holding integer pixel
      values (presumably in [0, vocab_size), e.g. 256 values).
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, number of distinct pixel values.

  Returns:
    A Tensor of shape [batch, height, width, model_hparams.hidden_size].

  Raises:
    ValueError: if x is not a rank-4 tensor.
  """
  pixel_embedding_size = 64
  inputs = x
  with tf.variable_scope("image_modality"):
    # Validate the rank before any op consumes the input, so a malformed
    # input raises this error instead of a shape error from the summary op.
    inputs_shape = common_layers.shape_list(inputs)
    if len(inputs_shape) != 4:
      raise ValueError("Assuming images given as int tensors in the format "
                       "[batch, height, width, channels] (256 values).")
    if not tf.executing_eagerly():
      tf.summary.image(
          "targets_bottom",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=1)
    # We embed each of 256=vocab_size possible pixel values.
    embedding_var = tf.get_variable(
        "pixel_embedding",
        [vocab_size, pixel_embedding_size])
    hot_inputs = tf.one_hot(tf.to_int32(inputs), vocab_size)
    hot_inputs = tf.reshape(hot_inputs, [-1, vocab_size])
    embedded = tf.matmul(hot_inputs, embedding_var)
    # Let's now merge all channels that were embedded into a single vector.
    merged_size = pixel_embedding_size * inputs_shape[3]
    embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
    merged = tf.layers.dense(
        embedded,
        model_hparams.hidden_size,
        name="merge_pixel_embedded_channels")
    return merged
def _image_channel_compress_bottom(inputs, model_hparams, name="bottom"):
  """Compresses channel-wise input pixels into whole pixel representations.

  RGB values are converted to reals in [-1, 1] and the channels are laid out
  side by side along the width. A ReLU convolution with a (1, 3) kernel and
  stride then merges the three channel values of each pixel into a single
  hidden_size vector.

  Args:
    inputs: Tensor representing RGB pixel intensities as integers, of shape
      [batch, img_len, img_len, channels].
    model_hparams: HParams, model hyperparameters.
    name: string, scope.

  Returns:
    body_input: Tensor of shape
      [batch, img_len, img_len, model_hparams.hidden_size].
  """
  num_channels = 3
  with tf.variable_scope(name):
    pixels = tf.to_float(inputs)
    if model_hparams.mode != tf_estimator.ModeKeys.PREDICT:
      tf.summary.image(
          "inputs",
          common_layers.tpu_safe_image_summary(pixels),
          max_outputs=2)
    pixels = common_layers.convert_rgb_to_symmetric_real(pixels)

    # View the image as [img_len, img_len * channels] so each pixel's channel
    # values are adjacent along the width axis.
    shape = common_layers.shape_list(pixels)
    pixels = tf.reshape(pixels, [-1, shape[1], shape[2] * shape[3], 1])

    # Compress RGB intensities for each pixel using a convolution.
    return tf.layers.conv2d(
        pixels,
        model_hparams.hidden_size,
        kernel_size=(1, num_channels),
        padding="VALID",
        strides=(1, num_channels),
        activation=tf.nn.relu,
        name="conv_input")
def image_channel_compress_bottom(x, model_hparams, vocab_size):
  """Channel-compressing bottom for input images, in scope "input_bottom"."""
  del vocab_size  # unused arg
  return _image_channel_compress_bottom(
      x, model_hparams, name="input_bottom")
def image_channel_compress_targets_bottom(x, model_hparams, vocab_size):
  """Target-side channel-compression bottom (variable scope "output_bottom").

  Args:
    x: Tensor of RGB pixel intensities, [batch, img_len, img_len, channels].
    model_hparams: HParams, model hyperparameters.
    vocab_size: unused.

  Returns:
    Tensor of shape [batch, img_len, img_len, model_hparams.hidden_size].
  """
  del vocab_size  # unused arg
  scope = "output_bottom"
  return _image_channel_compress_bottom(x, model_hparams, name=scope)
def image_channel_embeddings_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for image targets via per-channel embeddings.

  Embeds each channel value with `cia.get_channel_embeddings` and lays the
  channel embeddings out along the width axis.

  Args:
    x: image Tensor whose first three axes are [batch, height, width].
    model_hparams: HParams; reads `num_channels` and `hidden_size`.
    vocab_size: unused.

  Returns:
    Tensor of shape [batch, height, width * num_channels, hidden_size].
  """
  del vocab_size  # unused arg
  pixel_ids = tf.to_int32(x)
  channels = model_hparams.num_channels
  dims = common_layers.shape_list(pixel_ids)
  depth = model_hparams.hidden_size
  channel_embeddings = cia.get_channel_embeddings(
      channels, pixel_ids, depth, "input_bottom")
  out_shape = [dims[0], dims[1], dims[2] * channels, depth]
  return tf.reshape(channel_embeddings, out_shape)
def make_targets_bottom(bottom):
  """Wraps a bottom function so it runs under a "targets_bottom" scope.

  Args:
    bottom: callable `(x, model_hparams, vocab_size) -> Tensor`.

  Returns:
    A callable with the same signature that invokes `bottom` inside
    `tf.variable_scope("targets_bottom")`.
  """
  scope_name = "targets_bottom"

  def targets_bottom(x, model_hparams, vocab_size):
    """Calls the wrapped bottom inside the targets variable scope."""
    with tf.variable_scope(scope_name):
      result = bottom(x, model_hparams, vocab_size)
    return result

  return targets_bottom
def real_bottom(x, model_hparams, vocab_size):
  """Projects real-valued inputs to `hidden_size` with a dense layer.

  Args:
    x: Tensor of real (or castable) input features.
    model_hparams: HParams; reads `hidden_size`.
    vocab_size: unused.

  Returns:
    Tensor whose last dimension is model_hparams.hidden_size.
  """
  del vocab_size  # unused arg
  with tf.variable_scope("real"):
    features = tf.to_float(x)
    projected = tf.layers.dense(
        features, model_hparams.hidden_size, name="bottom")
    return projected
def speech_recognition_bottom(x, model_hparams, vocab_size):
  """Speech bottom: optional filterbanks, normalization, strided convolutions.

  If `audio_preproc_in_bottom` is set, mel filterbank features are computed
  here from the raw waveform (optionally with delta/delta-delta channels) and
  normalized with statistics computed from the batch itself, in place of
  CMVN estimated on the data. Otherwise `x` is used as precomputed features.
  Then two 3x3 convolutions with stride (2, 2) (128 filters, each followed
  by layer norm and ReLU) downsample the time and frequency axes, and a final
  convolution spanning all remaining frequency bins projects to
  `hidden_size`, leaving a frequency axis of size 1.

  Args:
    x: float32 Tensor. With `audio_preproc_in_bottom`, raw waveforms whose
      axes 2 and 3 are squeezed away (so [batch, samples, 1, 1]); otherwise
      precomputed features compatible with
      [batch, len, audio_num_mel_bins, num_channels] (enforced by set_shape),
      where num_channels is 3 with `audio_add_delta_deltas`, else 1.
    model_hparams: HParams, model hyperparameters (reads the `audio_*`
      settings and `hidden_size`).
    vocab_size: unused.

  Returns:
    float32 Tensor of shape [batch_size, shorter_len, 1, hidden_size].
  """
  del vocab_size  # unused arg
  inputs = x
  p = model_hparams

  num_mel_bins = p.audio_num_mel_bins
  num_channels = 3 if p.audio_add_delta_deltas else 1

  with tf.variable_scope("speech_recognition_modality"):
    if p.audio_preproc_in_bottom:
      # Compute filterbanks
      with tf.variable_scope("fbanks"):
        waveforms = tf.squeeze(inputs, [2, 3])
        mel_fbanks = common_audio.compute_mel_filterbank_features(
            waveforms,
            sample_rate=p.audio_sample_rate,
            dither=p.audio_dither,
            preemphasis=p.audio_preemphasis,
            frame_length=p.audio_frame_length,
            frame_step=p.audio_frame_step,
            lower_edge_hertz=p.audio_lower_edge_hertz,
            upper_edge_hertz=p.audio_upper_edge_hertz,
            num_mel_bins=p.audio_num_mel_bins,
            apply_mask=True)
        if p.audio_add_delta_deltas:
          mel_fbanks = common_audio.add_delta_deltas(mel_fbanks)
        # [batch, frames, num_mel_bins, num_channels]
        x = tf.reshape(mel_fbanks,
                       common_layers.shape_list(mel_fbanks)[:2] +
                       [num_mel_bins, num_channels])

        nonpadding_mask = 1. - common_attention.embedding_to_padding(x)
        # NOTE(review): reduce_sum without an axis yields a scalar pooled over
        # the whole batch, while the sums below are per example (axis 1), so
        # mean/variance are not exact per-utterance statistics; confirm intent.
        num_of_nonpadding_elements = tf.reduce_sum(
            nonpadding_mask) * num_mel_bins * num_channels

        # This replaces CMVN estimation on data
        var_epsilon = 1e-09
        mean = tf.reduce_sum(
            x, axis=[1], keepdims=True) / num_of_nonpadding_elements
        # Expanded form of sum((x - mean)^2) / N using sum(x) and sum(x^2).
        variance = (num_of_nonpadding_elements * mean**2. -
                    2. * mean * tf.reduce_sum(x, axis=[1], keepdims=True) +
                    tf.reduce_sum(x**2, axis=[1], keepdims=True)
                   ) / num_of_nonpadding_elements
        # Normalize, then re-zero padded frames via the mask.
        x = (x - mean) * tf.rsqrt(variance + var_epsilon) * tf.expand_dims(
            nonpadding_mask, -1)
    else:
      x = inputs

    # The convention is that the models are flattened along the spatial,
    # dimensions, thus the speech preprocessor treats frequencies and
    # channels as image colors (last axis)
    x.set_shape([None, None, num_mel_bins, num_channels])

    # TODO(chorowski): how to specify bottom's hparams and avoid hardcoding?
    x = tf.pad(x, [[0, 0], [0, 8], [0, 0], [0, 0]])
    # Two stride-2 convolutions downsample both the time and frequency axes.
    for _ in range(2):
      x = tf.layers.conv2d(
          x, 128, (3, 3), (2, 2), use_bias=False)
      x = common_layers.layer_norm(x)
      x = tf.nn.relu(x)

    xshape = common_layers.shape_list(x)
    # apply a conv that will remove all frequencies and at the same time
    # project the output into desired hidden_size
    x = tf.pad(x, [[0, 0], [0, 2], [0, 0], [0, 0]])
    x = tf.layers.conv2d(x, p.hidden_size, (3, xshape[2]), use_bias=False)

    assert common_layers.shape_list(x)[2] == 1
    x = common_layers.layer_norm(x)
    x = tf.nn.relu(x)
  return x
def get_weights(model_hparams, vocab_size, hidden_dim=None):
  """Creates (or fetches) the embedding/softmax matrix, possibly sharded.

  The vocabulary is split into `model_hparams.symbol_modality_num_shards`
  variables named "weights_0", "weights_1", ...; the first
  `vocab_size % num_shards` shards get one extra row. The shards are
  concatenated along axis 0.

  Args:
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.
    hidden_dim: int, width of the matrix. Defaults to
      model_hparams.hidden_size.

  Returns:
    A single Tensor of shape [vocab_size, hidden_dim] (not a list). Outside
    eager mode it is passed through common_layers.convert_gradient_to_tensor.
  """
  depth = model_hparams.hidden_size if hidden_dim is None else hidden_dim
  num_shards = model_hparams.symbol_modality_num_shards
  rows_per_shard, leftover = divmod(vocab_size, num_shards)

  shards = [
      tf.get_variable(
          "weights_%d" % shard_index,
          [rows_per_shard + (1 if shard_index < leftover else 0), depth],
          initializer=tf.random_normal_initializer(0.0, depth**-0.5))
      for shard_index in range(num_shards)
  ]
  weights = shards[0] if num_shards == 1 else tf.concat(shards, 0)

  # Convert weights to tensor.
  if not tf.executing_eagerly():
    weights = common_layers.convert_gradient_to_tensor(weights)
  return weights
def _symbol_bottom_simple(x, model_hparams, vocab_size, name, reuse):
  """Embeds integer symbol ids; id 0 is mapped to an all-zero vector.

  Args:
    x: integer Tensor of symbol ids. A rank-4 input has axis 3 squeezed;
      lower ranks are padded with trailing unit axes up to rank 3.
    model_hparams: HParams; reads `symbol_dropout`,
      `multiply_embedding_mode` and `hidden_size`.
    vocab_size: int, vocabulary size.
    name: string, variable scope name.
    reuse: reuse flag forwarded to tf.variable_scope.

  Returns:
    Tensor of embeddings with one extra trailing (embedding) axis.
  """
  with tf.variable_scope(name, reuse=reuse):
    # Bring the ids to rank 3.
    ids = tf.squeeze(x, axis=3) if len(x.get_shape()) == 4 else x
    while len(ids.get_shape()) < 3:
      ids = tf.expand_dims(ids, axis=-1)

    embedding_matrix = get_weights(model_hparams, vocab_size)
    # Symbol dropout (keep prob = 1 - symbol_dropout), without rescaling.
    ids = common_layers.dropout_no_scaling(
        ids, 1.0 - model_hparams.symbol_dropout)
    embedded = common_layers.gather(embedding_matrix, ids)

    if model_hparams.multiply_embedding_mode == "sqrt_depth":
      embedded *= model_hparams.hidden_size**0.5
    # Zero the embedding wherever the (post-dropout) id is 0.
    is_nonzero = common_layers.cast_like(tf.not_equal(ids, 0), embedded)
    embedded *= tf.expand_dims(is_nonzero, -1)
    return embedded
def symbol_bottom(x, model_hparams, vocab_size):
  """Embeds input symbols in scope "shared" or "input_emb".

  The "shared" scope is used when embeddings are shared with the softmax
  (`shared_embedding_and_softmax_weights`) or the hparams set
  `shared_embedding`; otherwise the input gets its own "input_emb" matrix.
  """
  use_shared = (model_hparams.shared_embedding_and_softmax_weights or
                model_hparams.get("shared_embedding"))
  scope = "shared" if use_shared else "input_emb"
  return _symbol_bottom_simple(x, model_hparams, vocab_size, scope, reuse=None)
def symbol_targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for target symbols.

  With shared embeddings, first tries to reuse the "shared" variables
  (normally created by the inputs bottom). If that raises ValueError, the
  variables are created fresh. Without sharing, a separate "target_emb"
  matrix is used.
  """
  use_shared = (model_hparams.shared_embedding_and_softmax_weights or
                model_hparams.get("shared_embedding"))
  if not use_shared:
    return _symbol_bottom_simple(
        x, model_hparams, vocab_size, "target_emb", reuse=None)
  try:
    return _symbol_bottom_simple(
        x, model_hparams, vocab_size, "shared", reuse=True)
  except ValueError:
    # perhaps there were no inputs, and this is a new variable.
    return _symbol_bottom_simple(
        x, model_hparams, vocab_size, "shared", reuse=None)
def symbol_one_hot_bottom(x, model_hparams, vocab_size):
  """One-hot encodes symbol ids into vectors of length `vocab_size`."""
  del model_hparams  # unused arg
  depth = vocab_size
  encoded = tf.one_hot(x, depth)
  return encoded
def video_bottom(x, model_hparams, vocab_size):
  """Logs an "inputs" GIF summary and returns the standardized frames."""
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("inputs", x, max_outputs=1)
  standardized = common_layers.standardize_images(x)
  return standardized
def video_targets_bottom(x, model_hparams, vocab_size):
  """Logs a "targets" GIF summary and returns the standardized frames."""
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("targets", x, max_outputs=1)
  standardized = common_layers.standardize_images(x)
  return standardized
def video_bitwise_bottom(x, model_hparams, vocab_size):
  """Bottom transformation that embeds video pixels bitwise.

  Pixels are embedded 8 bits at a time with `discretization.int_to_bit_embed`
  (64-dim bit embeddings) and then projected to `model_hparams.hidden_size`.
  Variables live in the "video_modality_bitwise" scope with AUTO_REUSE, so
  they are shared with video_bitwise_targets_bottom.

  Args:
    x: Tensor of integer pixel values.
    model_hparams: HParams; reads `hidden_size`.
    vocab_size: int, must be 256 (8 bits per pixel value).

  Returns:
    Tensor whose last dimension is model_hparams.hidden_size.

  Raises:
    ValueError: if vocab_size is not 256.
  """
  # Explicit check instead of `assert`, which `python -O` strips; the 8-bit
  # decomposition below is only valid for a 256-value vocabulary.
  if vocab_size != 256:
    raise ValueError(
        "video_bitwise_bottom requires vocab_size == 256, got %r" %
        (vocab_size,))
  pixel_embedding_size = 64
  inputs = x
  with tf.variable_scope("video_modality_bitwise", reuse=tf.AUTO_REUSE):
    common_layers.summarize_video(inputs, "bottom")
    # Embed bitwise.
    embedded = discretization.int_to_bit_embed(inputs, 8,
                                               pixel_embedding_size)
    # Project.
    return tf.layers.dense(
        embedded,
        model_hparams.hidden_size,
        name="merge_pixel_embedded_frames")
def video_bitwise_targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation that embeds target video pixels bitwise.

  Pixels are embedded 8 bits at a time with `discretization.int_to_bit_embed`
  (64-dim bit embeddings). The frames are moved into the channel axis with
  `common_layers.time_to_channels` and projected to
  `model_hparams.hidden_size`. The "merge_pixel_embedded_frames" layer is in
  the "video_modality_bitwise" scope with AUTO_REUSE.

  Args:
    x: Tensor of integer target pixel values.
    model_hparams: HParams; reads `hidden_size`.
    vocab_size: int, must be 256 (8 bits per pixel value).

  Returns:
    Tensor whose last dimension is model_hparams.hidden_size.

  Raises:
    ValueError: if vocab_size is not 256.
  """
  # Explicit check instead of `assert`, which `python -O` strips; the 8-bit
  # decomposition below is only valid for a 256-value vocabulary.
  if vocab_size != 256:
    raise ValueError(
        "video_bitwise_targets_bottom requires vocab_size == 256, got %r" %
        (vocab_size,))
  pixel_embedding_size = 64
  inputs = x
  with tf.variable_scope("video_modality_bitwise", reuse=tf.AUTO_REUSE):
    common_layers.summarize_video(inputs, "targets_bottom")
    # Embed bitwise.
    embedded = discretization.int_to_bit_embed(inputs, 8,
                                               pixel_embedding_size)
    # Transpose and project.
    transposed = common_layers.time_to_channels(embedded)
    return tf.layers.dense(
        transposed,
        model_hparams.hidden_size,
        name="merge_pixel_embedded_frames")
def video_identity_bottom(x, model_hparams, vocab_size):
  """Logs an "inputs" GIF summary and returns `x` unchanged."""
  del model_hparams, vocab_size  # unused arg
  frames = x
  common_video.gif_summary("inputs", frames, max_outputs=1)
  return frames
def video_identity_targets_bottom(x, model_hparams, vocab_size):
  """Logs a "targets" GIF summary and returns `x` unchanged."""
  del model_hparams, vocab_size  # unused arg
  frames = x
  common_video.gif_summary("targets", frames, max_outputs=1)
  return frames
def video_pixel_noise_bottom(x, model_hparams, vocab_size):
  """Video bottom that, in training, randomly replaces pixels by background.

  In TRAIN mode, each spatial/temporal position is replaced by the
  "background" with probability `video_modality_input_noise` (default 0.25).
  The background is the 50th percentile over axes [0, 1, 2, 3], taken per
  value of the last axis. The result then goes through video_bottom.
  """
  inputs = x
  if model_hparams.mode == tf_estimator.ModeKeys.TRAIN:
    noise_prob = getattr(model_hparams, "video_modality_input_noise", 0.25)
    background = tfp.stats.percentile(inputs, 50., axis=[0, 1, 2, 3])
    shape = common_layers.shape_list(inputs)
    num_positions = tf.reduce_prod(shape[:-1])
    # Category 0 (replace) has prob noise_prob, category 1 (keep) the rest.
    keep = tf.multinomial(
        tf.log([[noise_prob, 1. - noise_prob]]), num_positions)
    keep = tf.reshape(tf.cast(keep, tf.int32), shape[:-1] + [1])
    inputs = inputs * keep + background * (1 - keep)
  return video_bottom(inputs, model_hparams, vocab_size)
def convert_rgb_to_real(prediction, targets):
  """Converts prediction and targets from RGB to real values.

  The trailing axis of `prediction` is squeezed away before conversion.

  Returns:
    (real_prediction, real_targets) tuple.
  """
  squeezed = tf.squeeze(prediction, axis=-1)
  real_prediction = common_layers.convert_rgb_to_real(squeezed)
  real_targets = common_layers.convert_rgb_to_real(targets)
  return real_prediction, real_targets
def video_raw_bottom(x, model_hparams, vocab_size):
  """Logs an "inputs" GIF summary and converts RGB frames to real values."""
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("inputs", x)
  real_frames = common_layers.convert_rgb_to_real(x)
  return real_frames
def video_raw_targets_bottom(x, model_hparams, vocab_size):
  """Logs a "targets_bottom" GIF summary and converts RGB to real values."""
  del model_hparams, vocab_size  # unused arg
  common_video.gif_summary("targets_bottom", x)
  real_frames = common_layers.convert_rgb_to_real(x)
  return real_frames
# Loss transformations, applied to target features
def ctc_symbol_loss(top_out, targets, model_hparams, vocab_size, weight_fn):
  """Computes the CTC loss between logits and zero-padded target symbols.

  Args:
    top_out: logits Tensor shaped [batch, length, 1, 1, ...]; axes 2 and 3
      are squeezed.
    targets: integer Tensor shaped [batch, length, 1, 1]; id 0 is padding.
    model_hparams: unused.
    vocab_size: unused.
    weight_fn: callable mapping the squeezed targets to loss weights.

  Returns:
    (summed CTC loss, summed weights).
  """
  del model_hparams, vocab_size  # unused arg
  logits = top_out
  with tf.name_scope("ctc_loss", values=[logits, targets]):
    # CTC works on 1-d label sequences; they arrive as [batch, length, 1, 1].
    static_shape = targets.get_shape().as_list()
    assert len(static_shape) == 4
    assert static_shape[2:] == [1, 1]
    targets = tf.squeeze(targets, axis=[2, 3])
    logits = tf.squeeze(logits, axis=[2, 3])

    # Count the non-padding (non-zero) symbols of every example.
    is_symbol = 1 - tf.to_int32(tf.equal(targets, 0))
    lengths = tf.reduce_sum(is_symbol, axis=1)
    sparse_targets = tf.keras.backend.ctc_label_dense_to_sparse(
        targets, lengths)
    # NOTE(review): the third positional argument of the v1 ctc_loss is the
    # *logit* sequence length; target lengths are passed, which assumes the
    # logits are aligned with target positions — confirm.
    xent = tf.nn.ctc_loss(
        sparse_targets,
        logits,
        lengths,
        time_major=False,
        preprocess_collapse_repeated=False,
        ctc_merge_repeated=False)
    weights = weight_fn(targets)
    return tf.reduce_sum(xent), tf.reduce_sum(weights)
def generic_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Padded cross-entropy (numerator, denominator) for one output shard.

  Logits may be upcast via common_attention.maybe_upcast. The cutoff comes
  from `video_modality_loss_cutoff` (default 0.0).
  """
  del vocab_size  # unused arg
  upcast_logits = common_attention.maybe_upcast(top_out, hparams=model_hparams)
  loss_cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.0)
  return common_layers.padded_cross_entropy(
      upcast_logits, targets, model_hparams.label_smoothing,
      cutoff=loss_cutoff, weights_fn=weights_fn)
def generic_l2_loss(body_output,
                    targets,
                    model_hparams,
                    vocab_size,
                    weights_fn):
  """Mean squared difference between body output and float targets.

  Returns a constant denominator of 1.0, so the numerator is already the mean.
  """
  del model_hparams, vocab_size, weights_fn  # unused arg
  squared_diff = tf.squared_difference(body_output, tf.to_float(targets))
  return tf.reduce_mean(squared_diff), tf.constant(1.0)
def multi_label_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Cross-entropy averaged over each example's labels.

  The single prediction (rank-5 logits) is tiled along axis 1, once per
  label in `targets`. Each example's per-label losses are averaged by
  their weights, and only examples with positive total weight count.

  Returns:
    (sum of per-example mean losses, number of examples with labels).
  """
  del vocab_size  # unused arg
  num_labels = tf.shape(targets)[1]
  tiled_logits = tf.tile(top_out, [1, num_labels, 1, 1, 1])
  per_label_xent, per_label_weights = common_layers.padded_cross_entropy(
      tiled_logits,
      targets,
      model_hparams.label_smoothing,
      weights_fn=weights_fn,
      reduce_sum=False,
  )
  per_label_xent = tf.squeeze(per_label_xent, [2, 3])
  per_label_weights = tf.squeeze(per_label_weights, [2, 3])

  summed_xent = tf.reduce_sum(per_label_xent, axis=1)
  total_weight = tf.reduce_sum(per_label_weights, axis=1)
  # The epsilon guards examples whose labels are all padding.
  example_loss = summed_xent / (total_weight + 1e-8)
  has_labels = tf.to_float(tf.greater(total_weight, 0.))
  return tf.reduce_sum(example_loss * has_labels), tf.reduce_sum(has_labels)
def one_hot_class_label_loss(top_out,
                             targets,
                             model_hparams,
                             vocab_size,
                             weights_fn):
  """Softmax cross-entropy between logits and one-hot targets.

  Args:
    top_out: logits Tensor with shape [batch, ?, ?, num_classes].
    targets: one-hot Tensor with shape [batch, ?, ?, num_classes].
    model_hparams: unused.
    vocab_size: unused.
    weights_fn: callable mapping targets to weights; their sum is the
      returned denominator.

  Returns:
    (loss_scale, loss_denom): the tf.losses cross-entropy and summed weights.
  """
  del model_hparams, vocab_size  # unused arg
  xent = tf.losses.softmax_cross_entropy(
      onehot_labels=targets, logits=top_out)
  total_weight = tf.reduce_sum(weights_fn(targets))
  return xent, total_weight
def real_l2_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Weighted squared-error loss for real-valued targets.

  If `top_out` and `targets` differ in rank, the trailing axis of `top_out`
  is squeezed before comparison.

  Returns:
    (sum of weighted squared errors, sum of weights).
  """
  del model_hparams, vocab_size  # unused arg
  ranks_differ = (len(common_layers.shape_list(top_out)) !=
                  len(common_layers.shape_list(targets)))
  predictions = tf.squeeze(top_out, axis=[-1]) if ranks_differ else top_out
  with tf.name_scope("l2"):
    weights = weights_fn(targets)
    squared_error = tf.pow(predictions - targets, 2)
    return tf.reduce_sum(squared_error * weights), tf.reduce_sum(weights)
def real_log_poisson_loss(top_out,
                          targets,
                          model_hparams,
                          vocab_size,
                          weights_fn):
  """Weighted log-Poisson loss for real-valued targets.

  If `top_out` and `targets` differ in rank, the trailing axis of `top_out`
  is squeezed first.

  Returns:
    (sum of weighted losses, sum of weights).
  """
  del model_hparams, vocab_size  # unused arg
  ranks_differ = (len(common_layers.shape_list(top_out)) !=
                  len(common_layers.shape_list(targets)))
  predictions = tf.squeeze(top_out, axis=[-1]) if ranks_differ else top_out
  with tf.name_scope("log_possion"):
    weights = weights_fn(targets)
    poisson = tf.nn.log_poisson_loss(targets, predictions)
    return tf.reduce_sum(poisson * weights), tf.reduce_sum(weights)
def sigmoid_class_label_loss(top_out,
                             targets,
                             model_hparams,
                             vocab_size,
                             weights_fn):
  """Sigmoid cross-entropy loss for independent binary class labels.

  Expects inputs of size [batch-size, timesteps, 1, num-classes], where each
  entry of the last axis is the logit of a separate binary label.

  Returns:
    (tf.losses sigmoid cross-entropy, summed weights).
  """
  del model_hparams, vocab_size  # unused arg
  xent = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=targets, logits=top_out)
  total_weight = tf.reduce_sum(weights_fn(targets))
  return xent, total_weight
def sigmoid_max_pooling_class_label_loss(top_out,
                                         targets,
                                         model_hparams,
                                         vocab_size,
                                         weights_fn):
  """Sigmoid cross-entropy loss for max-pooled binary class labels.

  Expects inputs of size [batch-size, 1, 1, num-classes], where each entry of
  the last axis is the logit of a separate binary label.

  Returns:
    (tf.losses sigmoid cross-entropy, summed weights).
  """
  del model_hparams, vocab_size  # unused arg
  xent = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=targets, logits=top_out)
  total_weight = tf.reduce_sum(weights_fn(targets))
  return xent, total_weight
def symbol_one_hot_loss(top_out,
                        targets,
                        model_hparams,
                        vocab_size,
                        weights_fn):
  """Mean softmax cross-entropy against one-hot encoded symbol targets.

  Returns a constant denominator of 1.0, so the numerator is already the mean.
  """
  del model_hparams, weights_fn  # unused arg
  one_hot_targets = tf.one_hot(targets, vocab_size)
  xent = tf.nn.softmax_cross_entropy_with_logits(
      logits=top_out, labels=one_hot_targets)
  return tf.reduce_mean(xent), tf.constant(1.0)
def video_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Padded cross-entropy for video with batch and time merged.

  The first two axes of logits and targets are folded into one. The cutoff
  comes from `video_modality_loss_cutoff` (default 0.01).
  """
  del vocab_size  # unused arg
  logits_shape = common_layers.shape_list(top_out)
  merged_logits = tf.reshape(top_out, [-1] + logits_shape[2:])
  targets_shape = common_layers.shape_list(targets)
  merged_targets = tf.reshape(targets, [-1] + targets_shape[2:])
  loss_cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.01)
  return common_layers.padded_cross_entropy(
      merged_logits, merged_targets, model_hparams.label_smoothing,
      cutoff=loss_cutoff, weights_fn=weights_fn)
def video_identity_loss(top_out,
                        targets,
                        model_hparams,
                        vocab_size,
                        weights_fn):
  """Padded cross-entropy for identity video with batch and time merged.

  The first two axes of logits and targets are folded into one. The cutoff
  comes from `video_modality_loss_cutoff` (default 0.01).
  """
  del vocab_size  # unused arg
  # TODO(nikip): Try L2 loss
  logits_shape = common_layers.shape_list(top_out)
  merged_logits = tf.reshape(top_out, [-1] + logits_shape[2:])
  targets_shape = common_layers.shape_list(targets)
  merged_targets = tf.reshape(targets, [-1] + targets_shape[2:])
  loss_cutoff = getattr(model_hparams, "video_modality_loss_cutoff", 0.01)
  return common_layers.padded_cross_entropy(
      merged_logits, merged_targets, model_hparams.label_smoothing,
      cutoff=loss_cutoff, weights_fn=weights_fn)
def video_l1_internal_loss(logits, targets, model_hparams):
  """Element-wise absolute error, minus a margin and clipped at zero.

  The margin comes from `video_modality_loss_cutoff` (default 0.2).
  """
  margin = getattr(model_hparams, "video_modality_loss_cutoff", 0.2)
  abs_error = tf.abs(logits - targets)
  return tf.nn.relu(abs_error - margin)
def video_l1_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Weighted, cutoff L1 loss on video frames with batch and time merged.

  Returns:
    (sum of weighted losses, sum of weights).
  """
  del vocab_size  # unused arg
  logits_shape = common_layers.shape_list(top_out)
  logits = tf.reshape(top_out, [-1] + logits_shape[2:-1])
  targets_shape = common_layers.shape_list(targets)
  targets = tf.reshape(targets, [-1] + targets_shape[2:])
  weights = weights_fn(targets)
  # Train toward target + 0.5 so that truncating a prediction to int yields
  # the target value (e.g. 0 -> 0.5, 7 -> 7.5); metrics/inference cast to int.
  # Errors within the cutoff margin give zero loss because such predictions
  # already round to the correct value.
  shifted_targets = tf.to_float(targets) + 0.5
  loss = video_l1_internal_loss(logits, shifted_targets, model_hparams)
  return tf.reduce_sum(loss * weights), tf.reduce_sum(weights)
def video_l2_internal_loss(logits, targets, model_hparams):
  """Element-wise squared error, minus cutoff**2 and clipped at zero.

  The cutoff comes from `video_modality_loss_cutoff` (default 0.2).
  """
  margin = getattr(model_hparams, "video_modality_loss_cutoff", 0.2)
  squared_error = tf.squared_difference(logits, targets)
  return tf.nn.relu(squared_error - margin * margin)
def video_l2_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Weighted, cutoff L2 loss on video frames with batch and time merged.

  Returns:
    (sum of weighted losses, sum of weights).
  """
  del vocab_size  # unused arg
  logits_shape = common_layers.shape_list(top_out)
  logits = tf.reshape(top_out, [-1] + logits_shape[2:-1])
  targets_shape = common_layers.shape_list(targets)
  targets = tf.reshape(targets, [-1] + targets_shape[2:])
  weights = weights_fn(targets)
  # Train toward target + 0.5 so that truncating a prediction to int yields
  # the target value (e.g. 0 -> 0.5, 7 -> 7.5); metrics/inference cast to int.
  # Errors within the cutoff margin give zero loss because such predictions
  # already round to the correct value.
  shifted_targets = tf.to_float(targets) + 0.5
  loss = video_l2_internal_loss(logits, shifted_targets, model_hparams)
  return tf.reduce_sum(loss * weights), tf.reduce_sum(weights)
def video_l2_raw_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Mean squared error between real-valued prediction and target frames.

  Returns a constant denominator of 1.0, so the numerator is already the mean.
  """
  del model_hparams, vocab_size, weights_fn  # unused arg
  real_prediction, real_targets = convert_rgb_to_real(top_out, targets)
  mse = tf.losses.mean_squared_error(real_prediction, real_targets)
  return mse, tf.constant(1.0)
def video_l1_raw_loss(top_out, targets, model_hparams, vocab_size, weights_fn):
  """Mean absolute error between real-valued prediction and target frames.

  Returns a constant denominator of 1.0, so the numerator is already the mean.
  """
  del model_hparams, vocab_size, weights_fn  # unused arg
  real_prediction, real_targets = convert_rgb_to_real(top_out, targets)
  mae = tf.losses.absolute_difference(real_prediction, real_targets)
  return mae, tf.constant(1.0)
# Top transformations, applied to target features
def is_pointwise(func):
  """Decorator that flags a top transformation as pointwise.

  A pointwise top (e.g. a linear layer plus softmax) maps a
  [batch, length, height, depth] tensor by acting on the last axis of every
  [batch, length, height] position independently. A classifier that first
  averages over length and height is not pointwise. Knowing this lets some
  models speed up decoding.

  Args:
    func: function to flag.

  Returns:
    The same function object, with attribute `pointwise` set to True.
  """
  setattr(func, "pointwise", True)
  return func
def class_label_top(body_output, targets, model_hparams, vocab_size):
  """Maps body output to class logits: mean over axes 1 and 2, then dense.

  Args:
    body_output: Tensor with shape [batch, ?, ?, body_output_size].
    targets: unused.
    model_hparams: HParams; `hidden_size` is part of the scope name.
    vocab_size: int, number of classes.

  Returns:
    Tensor of shape [batch_size, 1, 1, 1, vocab_size].
  """
  del targets  # unused arg
  scope = "class_label_modality_%d_%d" % (vocab_size,
                                          model_hparams.hidden_size)
  with tf.variable_scope(scope):
    pooled = tf.reduce_mean(body_output, axis=[1, 2], keepdims=True)
    logits = tf.layers.dense(pooled, vocab_size)
    return tf.expand_dims(logits, 3)
def identity_top(body_output, targets, model_hparams, vocab_size):
  """Returns `body_output` unchanged; the other arguments are ignored."""
  del targets, model_hparams, vocab_size  # unused arg
  output = body_output
  return output
def image_top(body_output, targets, model_hparams, vocab_size):
  """Top transformation for images: per-channel logits over pixel values.

  Returns:
    Tensor of shape body_output.shape[:3] + [num_channels, vocab_size].
  """
  del targets  # unused arg
  # TODO(lukaszkaiser): is this a universal enough way to get channels?
  num_channels = model_hparams.problem.num_channels
  with tf.variable_scope("rgb_softmax"):
    out_shape = common_layers.shape_list(body_output)[:3] + [
        num_channels, vocab_size]
    logits = tf.layers.dense(body_output, vocab_size * num_channels)
    logits = tf.reshape(logits, out_shape)
    # Only log the predicted image the first time the scope is built.
    if not tf.get_variable_scope().reuse:
      predicted = tf.argmax(logits, axis=-1)
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(predicted),
          max_outputs=1)
    return logits
def image_channel_compress_top(body_output, targets, model_hparams, vocab_size):
  """Decompresses body output into per-channel logits.

  Args:
    body_output: Tensor of shape [batch, img_len, img_len, depth].
    targets: unused.
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    Tensor of shape [batch, img_len, img_len, channels, vocab_size].
  """
  del targets  # unused arg
  with tf.variable_scope("image_channel_compress_modality"):
    depth = model_hparams.hidden_size
    img_len = model_hparams.img_len
    num_rgb = 3  # RGB
    batch = common_layers.shape_list(body_output)[0]

    # Expand each pixel into num_rgb hidden vectors with a 1x1 convolution.
    expanded = tf.layers.conv2d(
        body_output,
        depth * num_rgb,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        activation=tf.nn.relu,
        name="decompress_conv")
    per_channel = tf.reshape(expanded,
                             [batch, img_len, img_len * num_rgb, depth])
    per_channel = common_layers.layer_preprocess(per_channel, model_hparams)
    logits = tf.layers.dense(per_channel,
                             vocab_size,
                             use_bias=True,
                             activation=None,
                             name="output_conv")
    return tf.reshape(logits, [batch, img_len, img_len, num_rgb, vocab_size])
def image_channel_embeddings_top(body_output,
                                 targets,
                                 model_hparams,
                                 vocab_size):
  """Top transformation for images with per-channel embeddings.

  Projects every position of `body_output` to `vocab_size` logits and reshapes
  them to [batch, img_len, img_len, num_channels, vocab_size].

  Args:
    body_output: Tensor whose positions cover img_len * img_len * num_channels
      entries per example.
    targets: unused.
    model_hparams: HParams; reads `img_len` and `num_channels`.
    vocab_size: int, number of values per channel (256 for 8-bit images).

  Returns:
    Tensor of shape [batch, img_len, img_len, num_channels, vocab_size].
  """
  del targets  # unused arg
  # The scope name is kept unchanged: renaming it would rename the variables
  # created below and break existing checkpoints.
  with tf.variable_scope("image_channel_embeddings_bottom"):
    img_len = model_hparams.img_len
    channels = model_hparams.num_channels
    # Project to vocab_size (previously a hard-coded 256) so the dense output
    # always matches the last axis of the reshape below. Otherwise the -1 batch
    # dimension could silently absorb the mismatch. For vocab_size == 256 the
    # variables are unchanged.
    x = tf.layers.dense(
        body_output, vocab_size, use_bias=True, activation=None,
        name="output_conv")
    x = tf.reshape(x,
                   [-1, img_len, img_len, channels, vocab_size])
    return x
@is_pointwise
def real_top(body_output, targets, model_hparams, vocab_size):
  """Pointwise dense projection of body output to `vocab_size` real outputs."""
  del targets, model_hparams  # unused arg
  with tf.variable_scope("real"):
    projected = tf.layers.dense(body_output, vocab_size, name="top")
    return projected
def sigmoid_max_pooling_class_label_top(body_output,
                                        targets,
                                        model_hparams,
                                        vocab_size):
  """Maps body output to class logits: max over timesteps, then dense.

  Args:
    body_output: Tensor with shape [batch, timesteps, 1, body_output_size].
    targets: unused.
    model_hparams: HParams; `hidden_size` is part of the scope name.
    vocab_size: int, number of classes.

  Returns:
    Tensor of shape [batch_size, 1, 1, vocab_size].
  """
  del targets  # unused arg
  scope = "sigmoid_max_pooling_class_symbol_modality_%d_%d" % (
      vocab_size, model_hparams.hidden_size)
  with tf.variable_scope(scope):
    pooled = tf.reduce_max(body_output, axis=1, keepdims=True)
    return tf.layers.dense(pooled, vocab_size)
def softmax_average_pooling_class_label_top(body_output,
                                            targets,
                                            model_hparams,
                                            vocab_size):
  """Maps body output to class logits: mean over axis 1, then dense."""
  del targets  # unused arg
  scope = "softmax_average_pooling_onehot_class_label_modality_%d_%d" % (
      vocab_size, model_hparams.hidden_size)
  with tf.variable_scope(scope):
    pooled = tf.reduce_mean(body_output, axis=1, keepdims=True)
    return tf.layers.dense(pooled, vocab_size)
def softmax_last_timestep_class_label_top(body_output,
                                          targets,
                                          model_hparams,
                                          vocab_size):
  """Maps body output to class logits from its last timestep (axis 1)."""
  del targets  # unused arg
  scope = "softmax_last_timestep_onehot_class_label_modality_%d_%d" % (
      vocab_size, model_hparams.hidden_size)
  with tf.variable_scope(scope):
    # Keep only the last timestep, retaining axis 1 with size 1.
    last_step = tf.expand_dims(body_output[:, -1], 1)
    return tf.layers.dense(last_step, vocab_size)
def softmax_max_pooling_class_label_top(body_output,
                                        targets,
                                        model_hparams,
                                        vocab_size):
  """Maps body output to class logits: max over axis 1, then dense."""
  del targets  # unused arg
  scope = "softmax_max_pooling_onehot_class_label_modality_%d_%d" % (
      vocab_size, model_hparams.hidden_size)
  with tf.variable_scope(scope):
    pooled = tf.reduce_max(body_output, axis=1, keepdims=True)
    return tf.layers.dense(pooled, vocab_size)
@is_pointwise
def symbol_top(body_output, targets, model_hparams, vocab_size):
  """Generates symbol logits from the body output.

  With `shared_embedding_and_softmax_weights`, the "shared" scope is entered
  with AUTO_REUSE; otherwise a "softmax" scope with reuse=False is used. With
  `factored_logits` in TRAIN mode, a common_layers.FactoredTensor is returned
  instead of dense logits.

  Args:
    body_output: Tensor with shape [batch, p0, p1, model_hparams.hidden_size].
    targets: unused.
    model_hparams: HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    logits: Tensor with shape [batch, p0, p1, 1, vocab_size].
  """
  del targets  # unused arg
  if model_hparams.shared_embedding_and_softmax_weights:
    scope_name, reuse = "shared", tf.AUTO_REUSE
  else:
    scope_name, reuse = "softmax", False

  with tf.variable_scope(scope_name, reuse=reuse):
    out_shape = common_layers.shape_list(body_output)
    depth = out_shape[-1]
    softmax_weights = get_weights(model_hparams, vocab_size, depth)
    use_factored = (model_hparams.factored_logits and
                    model_hparams.mode == tf_estimator.ModeKeys.TRAIN)
    if use_factored:
      # insert channels dimension
      return common_layers.FactoredTensor(
          tf.expand_dims(body_output, 3), softmax_weights)
    flat_output = tf.reshape(body_output, [-1, depth])
    logits = tf.matmul(flat_output, softmax_weights, transpose_b=True)
    return tf.reshape(logits, out_shape[:-1] + [1, vocab_size])
@is_pointwise
def symbol_one_hot_top(body_output, targets, model_hparams, vocab_size):
  """Pointwise identity top: returns `body_output` unchanged."""
  del targets, model_hparams, vocab_size  # unused arg
  logits = body_output
  return logits
def video_top(body_output, targets, model_hparams, vocab_size):
  """Splits the last axis of body output into per-channel video logits.

  Returns:
    Tensor of shape body_output.shape[:-1] + [num_channels, vocab_size].
  """
  del targets  # unused arg
  num_channels = model_hparams.problem.num_channels
  body_shape = common_layers.shape_list(body_output)
  logits = tf.reshape(body_output,
                      body_shape[:-1] + [num_channels, vocab_size])
  # Argmax over the vocabulary gives a GIF summary of the predicted frames.
  flat_logits = tf.reshape(logits, [-1, vocab_size])
  predicted = tf.reshape(tf.argmax(flat_logits, axis=-1),
                         body_shape[:-1] + [num_channels])
  common_video.gif_summary("results", predicted, max_outputs=1)
  return logits
def video_l1_top(body_output, targets, model_hparams, vocab_size):
  """Projects body output to per-frame, per-channel real predictions.

  Returns:
    Tensor of shape [batch, num_frames, d1, d2, num_channels, 1], where
    d1 and d2 are axes 1 and 2 of body_output.
  """
  del targets, vocab_size  # unused arg
  num_channels = model_hparams.problem.num_channels
  num_frames = model_hparams.video_num_target_frames
  with tf.variable_scope("rgb"):
    body_shape = common_layers.shape_list(body_output)
    frames = tf.layers.dense(body_output, num_channels * num_frames,
                             name="cast")
    frames = tf.reshape(frames, body_shape[:3] + [num_channels, num_frames])
    frames = tf.transpose(frames, [0, 4, 1, 2, 3])  # Frames next to batch.
    # Only log the last predicted frame the first time the scope is built.
    if not tf.get_variable_scope().reuse:
      last_frame = frames[:, -1, :, :, :]
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(last_frame),
          max_outputs=1)
    return tf.expand_dims(frames, axis=-1)  # Add an axis like in perplexity.
def video_raw_top(body_output, targets, model_hparams, vocab_size):
  """Converts real-valued frames to RGB, logging a "body_output" GIF summary.

  A list of frames is first stacked along axis 1.

  Returns:
    The RGB frames with an extra trailing axis of size 1.
  """
  del targets, model_hparams, vocab_size  # unused arg
  if isinstance(body_output, list):
    frames = tf.stack(body_output, axis=1)
  else:
    frames = body_output
  rgb = common_layers.convert_real_to_rgb(frames)
  common_video.gif_summary("body_output", rgb)
  return tf.expand_dims(rgb, axis=-1)
# Utility functions similar to tf.keras for default transformations
def get_bottom(modality_type, value=None):
  """Returns the default bottom transformation for `modality_type`.

  Args:
    modality_type: a ModalityType value.
    value: returned unchanged when the type has no default bottom.

  Returns:
    The default bottom function for the type, otherwise `value`.
  """
  # Ordered (modality types, bottom) pairs; the first match wins.
  defaults = (
      ((ModalityType.AUDIO,), audio_bottom),
      ((ModalityType.AUDIO_SPECTRAL,), audio_spectral_bottom),
      ((ModalityType.CLASS_LABEL,
        ModalityType.MULTI_LABEL,
        ModalityType.ONE_HOT_CLASS_LABEL,
        ModalityType.SIGMOID_CLASS_LABEL,
        ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
        ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL), class_label_bottom),
      ((ModalityType.CTC_SYMBOL,
        ModalityType.SYMBOL,
        ModalityType.SYMBOL_WEIGHTS_ALL), symbol_bottom),
      ((ModalityType.GENERIC_L2_LOSS,
        ModalityType.IDENTITY,
        ModalityType.IDENTITY_SYMBOL,
        ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM), identity_bottom),
      ((ModalityType.IMAGE,), image_bottom),
      ((ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
        ModalityType.IMAGE_CHANNEL_COMPRESS), image_channel_compress_bottom),
      ((ModalityType.REAL,
        ModalityType.REAL_L2_LOSS,
        ModalityType.REAL_LOG_POISSON_LOSS), real_bottom),
      ((ModalityType.SPEECH_RECOGNITION,), speech_recognition_bottom),
      ((ModalityType.SYMBOL_ONE_HOT,), symbol_one_hot_bottom),
      ((ModalityType.VIDEO,
        ModalityType.VIDEO_L1,
        ModalityType.VIDEO_L2), video_bottom),
      ((ModalityType.VIDEO_BITWISE,), video_bitwise_bottom),
      ((ModalityType.VIDEO_IDENTITY,), video_identity_bottom),
      ((ModalityType.VIDEO_L1_RAW,
        ModalityType.VIDEO_L2_RAW), video_raw_bottom),
      ((ModalityType.VIDEO_PIXEL_NOISE,), video_pixel_noise_bottom),
  )
  for modality_types, bottom in defaults:
    if modality_type in modality_types:
      return bottom
  return value
def get_loss(modality_type, value=None):
  """Returns the default loss transformation for `modality_type`.

  Args:
    modality_type: a ModalityType value.
    value: returned unchanged when the type has no default loss.

  Returns:
    The default loss function for the type, otherwise `value`.
  """
  # Ordered (modality types, loss) pairs; the first match wins.
  defaults = (
      ((ModalityType.AUDIO,
        ModalityType.AUDIO_SPECTRAL,
        ModalityType.CLASS_LABEL,
        ModalityType.IDENTITY,
        ModalityType.IDENTITY_SYMBOL,
        ModalityType.IMAGE,
        ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
        ModalityType.IMAGE_CHANNEL_COMPRESS,
        ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM,
        ModalityType.REAL,
        ModalityType.SPEECH_RECOGNITION,
        ModalityType.SYMBOL,
        ModalityType.SYMBOL_WEIGHTS_ALL), generic_loss),
      ((ModalityType.CTC_SYMBOL,), ctc_symbol_loss),
      ((ModalityType.GENERIC_L2_LOSS,), generic_l2_loss),
      ((ModalityType.MULTI_LABEL,), multi_label_loss),
      ((ModalityType.ONE_HOT_CLASS_LABEL,
        ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
        ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL),
       one_hot_class_label_loss),
      ((ModalityType.REAL_L2_LOSS,), real_l2_loss),
      ((ModalityType.REAL_LOG_POISSON_LOSS,), real_log_poisson_loss),
      ((ModalityType.SIGMOID_CLASS_LABEL,), sigmoid_class_label_loss),
      ((ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,),
       sigmoid_max_pooling_class_label_loss),
      ((ModalityType.SYMBOL_ONE_HOT,), symbol_one_hot_loss),
      ((ModalityType.VIDEO,
        ModalityType.VIDEO_BITWISE,
        ModalityType.VIDEO_PIXEL_NOISE), video_loss),
      ((ModalityType.VIDEO_IDENTITY,), video_identity_loss),
      ((ModalityType.VIDEO_L1,), video_l1_loss),
      ((ModalityType.VIDEO_L1_RAW,), video_l1_raw_loss),
      ((ModalityType.VIDEO_L2,), video_l2_loss),
      ((ModalityType.VIDEO_L2_RAW,), video_l2_raw_loss),
  )
  for modality_types, loss_fn in defaults:
    if modality_type in modality_types:
      return loss_fn
  return value
def get_name(modality_type, value=None):
  """Returns the default naming function for a modality type.

  The returned callable has the signature `(model_hparams, vocab_size) -> str`.
  For legacy reasons, modalities vary in their naming scheme. Future plans are
  to remove any need for get_name. We do not recommend using it.

  Args:
    modality_type: a `ModalityType` value.
    value: fallback returned when `modality_type` has no default name.

  Returns:
    A naming callable for `modality_type`, otherwise `value`.
  """

  def constant_namer(fixed_text):
    # Name that ignores both hyperparameters and vocabulary size.
    return lambda model_hparams, vocab_size: fixed_text

  def sized_namer(template):
    # Name parameterized by vocab size and hidden size, in that order.
    def name(model_hparams, vocab_size):
      return template % (vocab_size, model_hparams.hidden_size)
    return name

  fixed_names = (
      (ModalityType.AUDIO, "audio_modality"),
      (ModalityType.AUDIO_SPECTRAL, "audio_spectral_modality"),
      (ModalityType.GENERIC_L2_LOSS, "generic_l2_loss_modality"),
      (ModalityType.IDENTITY, "identity_modality"),
      (ModalityType.IMAGE, "image_modality"),
      (ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
       "image_channel_bottom_identity_modality"),
      (ModalityType.IMAGE_CHANNEL_COMPRESS, "image_channel_compress_modality"),
      (ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM,
       "image_channel_embeddings_bottom"),
      (ModalityType.REAL, "real_modality"),
      (ModalityType.REAL_L2_LOSS, "real_l2_loss_modality"),
      (ModalityType.REAL_LOG_POISSON_LOSS, "real_log_poisson_loss_modality"),
      (ModalityType.SPEECH_RECOGNITION, "speech_recognition_modality"),
      (ModalityType.VIDEO, "video_modality"),
      (ModalityType.VIDEO_BITWISE, "video_modality_bitwise"),
      (ModalityType.VIDEO_IDENTITY, "video_modality_identity"),
      (ModalityType.VIDEO_L1, "video_modality_l1"),
      (ModalityType.VIDEO_L1_RAW, "video_modality_l1_raw"),
      (ModalityType.VIDEO_L2, "video_modality_l2"),
      (ModalityType.VIDEO_L2_RAW, "video_modality_l2_raw"),
      (ModalityType.VIDEO_PIXEL_NOISE, "video_modality_pixel_noise"),
  )
  for fixed_type, fixed_text in fixed_names:
    if modality_type == fixed_type:
      return constant_namer(fixed_text)

  templated_names = (
      ((ModalityType.CLASS_LABEL,
        ModalityType.MULTI_LABEL,
        ModalityType.ONE_HOT_CLASS_LABEL), "class_label_modality_%d_%d"),
      ((ModalityType.CTC_SYMBOL,
        ModalityType.IDENTITY_SYMBOL,
        ModalityType.SYMBOL,
        ModalityType.SYMBOL_WEIGHTS_ALL,
        ModalityType.SYMBOL_ONE_HOT), "symbol_modality_%d_%d"),
      ((ModalityType.SIGMOID_CLASS_LABEL,),
       "sigmoid_class_symbol_modality_%d_%d"),
      ((ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,),
       "sigmoid_max_pooling_class_symbol_modality_%d_%d"),
      ((ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,),
       "softmax_average_pooling_onehot_class_label_modality_%d_%d"),
      ((ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,),
       "softmax_last_timestep_onehot_class_label_modality_%d_%d"),
      ((ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL,),
       "softmax_max_pooling_onehot_class_label_modality_%d_%d"),
  )
  for modality_group, template in templated_names:
    if modality_type in modality_group:
      return sized_namer(template)
  return value
def get_targets_bottom(modality_type, value=None):
  """Returns the default targets-side bottom for a modality type.

  Args:
    modality_type: a `ModalityType` value.
    value: fallback returned when `modality_type` has no default.

  Returns:
    The targets bottom callable for `modality_type`, otherwise `value`.
  """
  # Ordered (modality-group, bottom, wrap) triples. When `wrap` is set, the
  # bottom is passed through `make_targets_bottom` before being returned.
  bottom_table = (
      ((ModalityType.AUDIO,), audio_bottom, True),
      ((ModalityType.AUDIO_SPECTRAL,), audio_spectral_bottom, True),
      ((ModalityType.CLASS_LABEL,
        ModalityType.MULTI_LABEL,
        ModalityType.ONE_HOT_CLASS_LABEL,
        ModalityType.SIGMOID_CLASS_LABEL,
        ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,
        ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,
        ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL),
       class_label_targets_bottom, False),
      ((ModalityType.CTC_SYMBOL,
        ModalityType.SYMBOL,
        ModalityType.SYMBOL_WEIGHTS_ALL), symbol_targets_bottom, False),
      ((ModalityType.GENERIC_L2_LOSS,
        ModalityType.IDENTITY_SYMBOL), identity_bottom, False),
      ((ModalityType.IDENTITY,), identity_bottom, True),
      ((ModalityType.IMAGE,), image_targets_bottom, False),
      ((ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
        ModalityType.IMAGE_CHANNEL_COMPRESS),
       image_channel_compress_targets_bottom, False),
      ((ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM,),
       image_channel_embeddings_bottom, False),
      ((ModalityType.REAL,
        ModalityType.REAL_L2_LOSS,
        ModalityType.REAL_LOG_POISSON_LOSS), real_bottom, True),
      ((ModalityType.SPEECH_RECOGNITION,), speech_recognition_bottom, True),
      ((ModalityType.SYMBOL_ONE_HOT,), symbol_one_hot_bottom, False),
      ((ModalityType.VIDEO,
        ModalityType.VIDEO_L1,
        ModalityType.VIDEO_L2), video_targets_bottom, False),
      ((ModalityType.VIDEO_BITWISE,), video_bitwise_targets_bottom, False),
      ((ModalityType.VIDEO_IDENTITY,), video_identity_targets_bottom, False),
      ((ModalityType.VIDEO_L1_RAW,
        ModalityType.VIDEO_L2_RAW), video_raw_targets_bottom, False),
      ((ModalityType.VIDEO_PIXEL_NOISE,), video_pixel_noise_bottom, True),
  )
  for modality_group, bottom_fn, wrap in bottom_table:
    if modality_type not in modality_group:
      continue
    if wrap:
      return make_targets_bottom(bottom_fn)
    return bottom_fn
  return value
def get_top(modality_type, value=None):
  """Returns the default top transformation for a modality type.

  Args:
    modality_type: a `ModalityType` value.
    value: fallback returned when `modality_type` has no default top.

  Returns:
    The top callable registered for `modality_type`, otherwise `value`.
  """
  # Ordered (modality-group, top) pairs; groups are disjoint.
  top_table = (
      ((ModalityType.AUDIO,
        ModalityType.AUDIO_SPECTRAL,
        ModalityType.GENERIC_L2_LOSS,
        ModalityType.IDENTITY,
        ModalityType.IDENTITY_SYMBOL,
        ModalityType.IMAGE_CHANNEL_BOTTOM_IDENTITY,
        ModalityType.SPEECH_RECOGNITION,
        ModalityType.VIDEO_IDENTITY), identity_top),
      ((ModalityType.CLASS_LABEL,
        ModalityType.MULTI_LABEL,
        ModalityType.ONE_HOT_CLASS_LABEL,
        ModalityType.SIGMOID_CLASS_LABEL), class_label_top),
      ((ModalityType.CTC_SYMBOL,
        ModalityType.SYMBOL,
        ModalityType.SYMBOL_WEIGHTS_ALL), symbol_top),
      ((ModalityType.IMAGE,), image_top),
      ((ModalityType.IMAGE_CHANNEL_COMPRESS,), image_channel_compress_top),
      ((ModalityType.IMAGE_CHANNEL_EMBEDDINGS_BOTTOM,),
       image_channel_embeddings_top),
      ((ModalityType.REAL,
        ModalityType.REAL_L2_LOSS,
        ModalityType.REAL_LOG_POISSON_LOSS), real_top),
      ((ModalityType.SIGMOID_MAX_POOLING_CLASS_LABEL,),
       sigmoid_max_pooling_class_label_top),
      ((ModalityType.SOFTMAX_AVERAGE_POOLING_CLASS_LABEL,),
       softmax_average_pooling_class_label_top),
      ((ModalityType.SOFTMAX_LAST_TIMESTEP_CLASS_LABEL,),
       softmax_last_timestep_class_label_top),
      ((ModalityType.SOFTMAX_MAX_POOLING_CLASS_LABEL,),
       softmax_max_pooling_class_label_top),
      ((ModalityType.SYMBOL_ONE_HOT,), symbol_one_hot_top),
      ((ModalityType.VIDEO,
        ModalityType.VIDEO_BITWISE,
        ModalityType.VIDEO_PIXEL_NOISE), video_top),
      ((ModalityType.VIDEO_L1,
        ModalityType.VIDEO_L2), video_l1_top),
      ((ModalityType.VIDEO_L1_RAW,
        ModalityType.VIDEO_L2_RAW), video_raw_top),
  )
  for modality_group, top_fn in top_table:
    if modality_type in modality_group:
      return top_fn
  return value
def get_weights_fn(modality_type, value=None):
  """Returns the default loss-weighting function for a modality type.

  Symbol-like modalities get `common_layers.weights_nonzero`; every other
  modality listed in `ModalityType.get_choices()` gets
  `common_layers.weights_all`.

  Args:
    modality_type: a `ModalityType` value.
    value: fallback returned for types not known to `ModalityType`.

  Returns:
    A weights function, otherwise `value`.
  """
  nonzero_weighted = (ModalityType.CTC_SYMBOL,
                      ModalityType.IDENTITY_SYMBOL,
                      ModalityType.MULTI_LABEL,
                      ModalityType.SYMBOL,
                      ModalityType.SYMBOL_ONE_HOT)
  if modality_type in nonzero_weighted:
    return common_layers.weights_nonzero
  if modality_type not in ModalityType.get_choices():
    return value
  return common_layers.weights_all
================================================
FILE: tensor2tensor/layers/modalities_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Modalities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import modalities
from tensor2tensor.utils import expert_utils
from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# Enable eager execution at import time, before any test runs. NOTE(review):
# presumably needed so test_utils.run_in_graph_and_eager_modes can exercise
# the eager path — confirm against test_utils.
tf.enable_eager_execution()
class ModalityTest(tf.test.TestCase):
  """Tests for the default modality transformations in `modalities`."""

  @test_utils.run_in_graph_and_eager_modes()
  def testGetForAllModalities(self):
    """Every registered modality type has all six default transformations."""
    for modality in modalities.ModalityType.get_choices():
      bottom = modalities.get_bottom(modality)
      loss = modalities.get_loss(modality)
      name = modalities.get_name(modality)
      targets_bottom = modalities.get_targets_bottom(modality)
      top = modalities.get_top(modality)
      weights_fn = modalities.get_weights_fn(modality)
      self.assertIsNotNone(bottom,
                           msg="{} has no default bottom".format(modality))
      self.assertIsNotNone(loss, msg="{} has no default loss".format(modality))
      self.assertIsNotNone(name, msg="{} has no default name".format(modality))
      self.assertIsNotNone(
          targets_bottom,
          msg="{} has no default targets_bottom".format(modality))
      self.assertIsNotNone(top, msg="{} has no default top".format(modality))
      self.assertIsNotNone(weights_fn,
                           msg="{} has no default weights_fn".format(modality))

  @test_utils.run_in_graph_and_eager_modes()
  def testSymbolModalityInputs(self):
    """SYMBOL bottom maps [B, L, 1, 1] ids to [B, L, 1, hidden_size]."""
    batch_size = 10
    num_datashards = 5
    length = 5
    vocab_size = 5000
    hidden_size = 9
    model_hparams = common_hparams.basic_params1()
    model_hparams.hidden_size = hidden_size
    model_hparams.mode = tf_estimator.ModeKeys.TRAIN
    x = np.random.randint(
        vocab_size, size=(batch_size, length, 1, 1))
    data_parallelism = expert_utils.Parallelism(
        ["/device:CPU:0"] * num_datashards)
    xs = tf.split(x, num_datashards)
    sharded_output = data_parallelism(
        modalities.get_bottom(modalities.ModalityType.SYMBOL),
        xs,
        model_hparams,
        vocab_size)
    output = tf.concat(sharded_output, 0)
    self.evaluate(tf.global_variables_initializer())
    res = self.evaluate(output)
    self.assertEqual(res.shape, (batch_size, length, 1, hidden_size))

  @test_utils.run_in_graph_and_eager_modes()
  def testSymbolModalityTargets(self):
    """SYMBOL top yields [B, L, H, 1, vocab] logits and a scalar loss."""
    batch_size = 10
    num_datashards = 5
    length = 6
    height = 7
    hidden_size = 9
    vocab_size = 11
    model_hparams = common_hparams.basic_params1()
    model_hparams.hidden_size = hidden_size
    model_hparams.mode = tf_estimator.ModeKeys.TRAIN
    body_output = np.random.randint(
        100, size=(batch_size, length, height, hidden_size))
    targets = np.random.randint(
        vocab_size, size=(batch_size, length, height, 1))
    data_parallelism = expert_utils.Parallelism(
        ["/device:CPU:0"] * num_datashards)
    # tf.cast replaces the deprecated tf.to_float.
    sharded_body_output = tf.split(tf.cast(body_output, tf.float32),
                                   num_datashards)
    sharded_targets = tf.split(targets, num_datashards)
    sharded_logits = data_parallelism(
        modalities.get_top(modalities.ModalityType.SYMBOL),
        sharded_body_output,
        sharded_targets,
        model_hparams,
        vocab_size)
    sharded_loss_num, sharded_loss_den = data_parallelism(
        modalities.get_loss(modalities.ModalityType.SYMBOL),
        sharded_logits,
        sharded_targets,
        model_hparams,
        vocab_size,
        modalities.get_weights_fn(modalities.ModalityType.SYMBOL))
    train_loss = (tf.add_n(sharded_loss_num) /
                  tf.maximum(1.0, tf.add_n(sharded_loss_den)))
    logits = tf.concat(sharded_logits, 0)
    self.evaluate(tf.global_variables_initializer())
    res1, res2 = self.evaluate((logits, train_loss))
    self.assertEqual(res1.shape, (batch_size, length, height, 1, vocab_size))
    self.assertEqual(res2.shape, ())

  @test_utils.run_in_graph_mode_only()
  def testSymbolModalityTargetsFactored(self):
    """Same as testSymbolModalityTargets but with factored_logits enabled."""
    batch_size = 10
    num_datashards = 5
    length = 6
    height = 7
    hidden_size = 9
    vocab_size = 11
    model_hparams = common_hparams.basic_params1()
    model_hparams.factored_logits = True
    model_hparams.hidden_size = hidden_size
    model_hparams.mode = tf_estimator.ModeKeys.TRAIN
    body_output = np.random.randint(
        100, size=(batch_size, length, height, hidden_size))
    targets = np.random.randint(
        vocab_size, size=(batch_size, length, height, 1))
    data_parallelism = expert_utils.Parallelism(
        ["/device:CPU:0"] * num_datashards)
    # cached_session is the supported replacement for the deprecated
    # test_session.
    with self.cached_session() as session:
      sharded_body_output = tf.split(tf.cast(body_output, tf.float32),
                                     num_datashards)
      sharded_targets = tf.split(targets, num_datashards)
      sharded_logits = data_parallelism(
          modalities.get_top(modalities.ModalityType.SYMBOL),
          sharded_body_output,
          sharded_targets,
          model_hparams,
          vocab_size)
      sharded_loss_num, sharded_loss_den = data_parallelism(
          modalities.get_loss(modalities.ModalityType.SYMBOL),
          sharded_logits,
          sharded_targets,
          model_hparams,
          vocab_size,
          modalities.get_weights_fn(modalities.ModalityType.SYMBOL))
      train_loss = (tf.add_n(sharded_loss_num) /
                    tf.maximum(1.0, tf.add_n(sharded_loss_den)))
      logits = tf.concat(sharded_logits, 0)
      session.run(tf.global_variables_initializer())
      res1, res2 = session.run((logits, train_loss))
    self.assertEqual(res1.shape, (batch_size, length, height, 1, vocab_size))
    self.assertEqual(res2.shape, ())
if __name__ == "__main__":
  # Run this module's tests with TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/ngram.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""N-gram layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf
class NGram(tf.keras.layers.Layer):
  r"""N-gram count layer.

  The layer takes as input an integer Tensor of shape [..., length], each
  element of which is a token index in [0, input_dim). It returns a real-valued
  Tensor of shape [..., num_ngrams] that, for every n in [minval, maxval) and
  every one of the input_dim**n possible n-grams, counts how often that n-gram
  occurs in a batch element. The total number of n-grams is

  ```none
  num_ngrams = \sum_{minval <= n < maxval} input_dim^n.
  ```

  N-grams are read from consecutive, non-overlapping chunks of n tokens (the
  trailing `length % n` tokens are ignored), not from a sliding window. For
  example, with n = 2 the sequence [0, 0, 0] yields a single (0, 0) bigram.
  Output columns are grouped by n in increasing order. Within a group, the
  n-gram (t_1, ..., t_n) sits at offset t_1 * input_dim**(n-1) + ... + t_n
  (lexicographic order).
  """

  def __init__(self, input_dim, minval, maxval, **kwargs):
    """Constructs layer.

    Args:
      input_dim: int > 0. Size of the vocabulary, i.e. maximum integer index +
        1.
      minval: Lowest inclusive value of n for computing n-grams; must be >= 1.
        For example, setting it to 1 will compute starting from unigrams.
      maxval: Highest non-inclusive value of n for computing n-grams; must be
        greater than `minval`. For example, setting it to 3 will compute at
        most bigrams.
      **kwargs: kwargs of parent class.

    Raises:
      ValueError: if `minval < 1` or `maxval <= minval`. Otherwise these
        values would fail later with a division by zero or an empty concat.
    """
    if minval < 1 or maxval <= minval:
      raise ValueError(
          "NGram requires 1 <= minval < maxval; got minval=%r, maxval=%r." %
          (minval, maxval))
    super(NGram, self).__init__(**kwargs)
    self.input_dim = input_dim
    self.minval = minval
    self.maxval = maxval

  def call(self, inputs):
    """Counts the n-grams of every order n in [minval, maxval).

    Args:
      inputs: integer Tensor of shape [..., length] with values in
        [0, input_dim). Its rank must be statically known.

    Returns:
      float32 Tensor of shape [..., num_ngrams].
    """
    batch_shape = tf.shape(inputs)[:-1]
    length = tf.shape(inputs)[-1]
    # Number of leading batch axes (needs a static rank).
    num_batch_dims = len(inputs.shape) - 1
    ngram_range_counts = []
    for n in range(self.minval, self.maxval):
      num_chunks = length // n
      # Reshape inputs from [..., length] to [..., 1, length // n, n], dropping
      # remainder elements. Each n-vector is an ngram.
      reshaped_inputs = tf.reshape(
          inputs[..., :(n * num_chunks)],
          tf.concat([batch_shape, [1], num_chunks[tf.newaxis], [n]], 0))
      # Count the number of times each ngram appears in the input. We do so by
      # checking whether each n-vector in the input is equal to each n-vector
      # in a Tensor of all possible ngrams. The comparison is batched between
      # the input Tensor of shape [..., 1, length // n, n] and the ngrams Tensor
      # of shape [..., input_dim**n, 1, n].
      # The table of all ngrams is cast to the input dtype. Otherwise the
      # Python int list converts to int32, and int64 inputs would fail in
      # tf.equal.
      ngrams = tf.cast(
          tf.reshape(list(np.ndindex((self.input_dim,) * n)),
                     [1] * num_batch_dims + [self.input_dim**n, 1, n]),
          inputs.dtype)
      is_ngram = tf.equal(
          tf.reduce_sum(tf.cast(tf.equal(reshaped_inputs, ngrams), tf.int32),
                        axis=-1),
          n)
      # Sum over the chunk axis -> [..., input_dim**n].
      ngram_counts = tf.reduce_sum(tf.cast(is_ngram, tf.float32), axis=-1)
      ngram_range_counts.append(ngram_counts)
    return tf.concat(ngram_range_counts, axis=-1)

  def compute_output_shape(self, input_shape):
    """Replaces the last axis of `input_shape` by the number of n-grams."""
    input_shape = tf.TensorShape(input_shape)
    num_ngrams = sum(self.input_dim**n
                     for n in range(self.minval, self.maxval))
    return input_shape[:-1].concatenate([num_ngrams])

  def get_config(self):
    """Returns the layer config, including every constructor argument."""
    # input_dim is required by __init__. Without it, from_config fails.
    config = {'input_dim': self.input_dim,
              'minval': self.minval,
              'maxval': self.maxval}
    base_config = super(NGram, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
================================================
FILE: tensor2tensor/layers/ngram_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for n-gram layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import ngram
from tensor2tensor.utils import test_utils
import tensorflow.compat.v1 as tf
# Enable eager execution at import time, before any test runs. NOTE(review):
# presumably needed so test_utils.run_in_graph_and_eager_modes can exercise
# the eager path — confirm against test_utils.
tf.enable_eager_execution()
class NGramTest(tf.test.TestCase):
  """Tests for `ngram.NGram`."""

  @test_utils.run_in_graph_and_eager_modes()
  def testNGramLayerShape(self):
    """Output has one column per possible n-gram with minval <= n < maxval."""
    batch, seq_len, vocab = 2, 8, 3
    lo, hi = 1, 4
    tokens = tf.random_uniform(
        [batch, seq_len], minval=0, maxval=vocab, dtype=tf.int32)
    counts = self.evaluate(ngram.NGram(vocab, lo, hi)(tokens))
    expected_width = sum(vocab**n for n in range(lo, hi))
    self.assertEqual(counts.shape, (batch, expected_width))

  @test_utils.run_in_graph_and_eager_modes()
  def testNGramLayerOutput(self):
    """Unigram and bigram counts match hand-computed values."""
    tokens = tf.constant(
        [[0, 0, 0, 0, 1],
         [2, 1, 2, 1, 0]], dtype=tf.int32)
    counts = ngram.NGram(3, minval=1, maxval=3)(tokens)
    # Columns 0-2 are unigram counts; columns 3-11 are bigram counts taken
    # from non-overlapping token pairs, in lexicographic bigram order.
    expected = tf.constant(
        [[4., 1., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 2., 2., 0., 0., 0., 0., 0., 0., 0., 2., 0.]], dtype=tf.float32)
    counts_val, expected_val = self.evaluate([counts, expected])
    self.assertAllEqual(counts_val, expected_val)
if __name__ == "__main__":
  # Run this module's tests with TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/layers/transformer_glow_layers.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Glow operations for text.
Adapted glow operations from tensor2tensor.models.research.glow_ops to be used
as a prior in Text VAEs (specifically for MT). Supports:
1. Log determinant Jacobian computation with variable length data and masking.
2. Transformer instead of convolution as a basic transformation.
3. Every transformation (affine, split) conditions on the source
sentence.
4. Three different split functions in affine coupling.
5. Multi-head 1x1 convolution.
6. Actnorm with weight normalization.
Implementation based on Ma et al., 2019: https://arxiv.org/pdf/1909.02480.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import scipy
from tensor2tensor.layers import common_layers
import tensor2tensor.layers.transformer_glow_layers_ops as gops
import tensorflow.compat.v1 as tf
def actnorm(name, x, x_mask, inverse, init, logscale_factor=3.0):
  """Activation normalization: a per-channel affine map with log-det.

  Forward computes `(x + b) * exp(log_w)`; inverse computes
  `x * exp(-log_w) - b`. The variables `b` and `log_w` are obtained from
  `gops.get_variable_ddi` with candidate initial values `-mean` and
  `-0.5 * log(var + eps) / logscale_factor`, where mean/var are the masked
  moments of `x` (NOTE(review): presumably used only when `init` is set —
  confirm in gops).

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    inverse: forward or inverse pass.
    init: init flag forwarded to gops.get_variable_ddi.
    logscale_factor: the stored log-scale variable is multiplied by this
      factor on read.

  Returns:
    x: transformed Tensor, shape=[B, L, C].
    logabsdet: log absolute determinant Jacobian, shape=[B].
  """
  epsilon = tf.keras.backend.epsilon()
  channels = common_layers.shape_list(x)[2]
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    mean, var = gops.moments_over_bl(x, x_mask)
    bias = gops.get_variable_ddi(
        "b", channels, -mean, init, tf.zeros_initializer)
    log_scale_init = -0.5 * tf.log(var + epsilon) / logscale_factor
    log_scale = logscale_factor * 1 and gops.get_variable_ddi(
        "log_w", channels, log_scale_init, init,
        tf.zeros_initializer) * logscale_factor
    if inverse:
      x = x * tf.exp(-log_scale) - bias
    else:
      x = (x + bias) * tf.exp(log_scale)
    # Each unmasked position contributes sum(log_w) to log|det J|.
    num_positions = tf.reduce_sum(x_mask, -1)
    logabsdet = num_positions * tf.reduce_sum(log_scale)
    if inverse:
      logabsdet *= -1
  return x, logabsdet
def multihead_invertible_1x1_conv_np(
    name, x, x_mask, multihead_split, inverse, dtype):
  """Multi-head invertible 1x1 convolution on x, LU-parameterized.

  The channel axis is divided into `n_channels_all // 32` heads of 32
  channels. Each head is multiplied by its own 32x32 matrix W = P @ L @ U:
  - P is the permutation from an LU decomposition (gradient-stopped).
  - L is unit lower-triangular.
  - U is upper-triangular with diagonal sign_s * exp(log_s), where sign_s is
    gradient-stopped.
  Because |det P| = 1 and det L = 1, log|det W| = sum(log_s) for each head.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C], with C a multiple of 32 (asserted).
    x_mask: 2-D Tensor, shape=[B, L]. It is summed over the last axis to count
      positions.
    multihead_split: "a" assigns channels to heads in an interleaved way
      (channel k goes to head k % n_heads). "c" assigns contiguous blocks of
      32 channels.
    inverse: if True, multiply by W^-1 and negate logabsdet.
    dtype: dtype of the created variables.

  Returns:
    x: transformed Tensor, shape=[B, L, C].
    logabsdet: log absolute determinant Jacobian, shape=[B].

  Raises:
    ValueError: if multihead_split is not "a" or "c". The error is raised
      after the variables have been created.
  """
  batch_size, length, n_channels_all = common_layers.shape_list(x)
  assert n_channels_all % 32 == 0
  n_channels = 32
  n_1x1_heads = n_channels_all // n_channels

  def get_init_np():
    """Initializer function for multihead 1x1 parameters using numpy.

    For each head, LU-decomposes a random orthogonal matrix and packs the
    rows [P; L; U_strict; sign(diag U); log|diag U|] into one
    [3*32 + 2, 32] slab. The slabs are stacked over heads.
    """
    # NOTE(review): relies on `scipy.linalg` being reachable after a bare
    # `import scipy`. Older SciPy releases need `import scipy.linalg`
    # explicitly — confirm for the supported SciPy versions.
    results = []
    for _ in range(n_1x1_heads):
      random_matrix = np.random.rand(n_channels, n_channels)
      np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")
      np_p, np_l, np_u = scipy.linalg.lu(np_w)
      np_s = np.diag(np_u)
      np_sign_s = np.sign(np_s)[np.newaxis, :]
      np_log_s = np.log(np.abs(np_s))[np.newaxis, :]
      np_u = np.triu(np_u, k=1)
      results.append(
          np.concatenate([np_p, np_l, np_u, np_sign_s, np_log_s], axis=0))
    return tf.convert_to_tensor(np.stack(results, axis=0))

  def get_mask_init():
    # Strictly-lower and strictly-upper triangular 0/1 masks per head
    # (band_part(ones, 0, 0) is the diagonal, subtracted off).
    ones = tf.ones([n_1x1_heads, n_channels, n_channels], dtype=dtype)
    l_mask = tf.matrix_band_part(ones, -1, 0) - tf.matrix_band_part(ones, 0, 0)
    u_mask = tf.matrix_band_part(ones, 0, -1) - tf.matrix_band_part(ones, 0, 0)
    return tf.stack([l_mask, u_mask], axis=0)

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    params = tf.get_variable("params", initializer=get_init_np, dtype=dtype)
    mask_params = tf.get_variable(
        "mask_params", initializer=get_mask_init, dtype=dtype, trainable=False)

    # Unpack the per-head slab: rows [0, n) = P, [n, 2n) = L, [2n, 3n) = U,
    # row 3n = sign(s), row 3n+1 = log|s|.
    p = tf.stop_gradient(params[:, :n_channels, :])
    l = params[:, n_channels : 2*n_channels, :]
    u = params[:, 2*n_channels : 3*n_channels, :]
    sign_s = tf.stop_gradient(params[:, 3*n_channels, :])
    log_s = params[:, 3*n_channels+1, :]

    l_mask = mask_params[0]
    u_mask = mask_params[1]

    # L gets a unit diagonal; U gets the diagonal sign_s * exp(log_s).
    l_diag = l * l_mask + (
        tf.eye(n_channels, n_channels, [n_1x1_heads], dtype=dtype))
    u_diag = u * u_mask + (
        tf.matrix_diag(sign_s * tf.exp(log_s)))
    w = tf.matmul(p, tf.matmul(l_diag, u_diag))

    if multihead_split == "a":
      x = tf.reshape(x, [batch_size, length, n_channels, n_1x1_heads])
      x = tf.transpose(x, [3, 0, 1, 2])
    elif multihead_split == "c":
      x = tf.reshape(x, [batch_size, length, n_1x1_heads, n_channels])
      x = tf.transpose(x, [2, 0, 1, 3])
    else:
      raise ValueError("Multihead split not supported.")
    # [n_1x1_heads, batch_size, length, n_channels]

    if not inverse:
      # [n_1x1_heads, 1, n_channels, n_channels]
      x = tf.matmul(x, w[:, tf.newaxis, :, :])
    else:
      w_inv = tf.matrix_inverse(w)
      x = tf.matmul(x, w_inv[:, tf.newaxis, :, :])

    # Undo the head split. The final else is unreachable because an invalid
    # multihead_split already raised above.
    if multihead_split == "a":
      x = tf.transpose(x, [1, 2, 3, 0])
      x = tf.reshape(x, [batch_size, length, n_channels * n_1x1_heads])
    elif multihead_split == "c":
      x = tf.transpose(x, [1, 2, 0, 3])
      x = tf.reshape(x, [batch_size, length, n_1x1_heads * n_channels])
    else:
      raise ValueError("Multihead split not supported.")

    # Every unmasked position contributes sum over heads of log|det W|.
    x_length = tf.reduce_sum(x_mask, -1)
    logabsdet = x_length * tf.reduce_sum(log_s)
    if inverse:
      logabsdet *= -1
  return x, logabsdet
def coupling(*args, **kwargs):
  """Coupling transform layer: dispatches to affine or additive coupling.

  Affine coupling is used if either `hparams.prior_type` or
  `hparams.posterior_type` is "affine". Otherwise, additive coupling is used if
  either of them is "additive".

  Args:
    *args: forwarded to the selected coupling function.
    **kwargs: forwarded to the selected coupling function. Must contain
      `hparams` with `prior_type` and `posterior_type` attributes.

  Returns:
    The `(result, logabsdet)` tuple of the selected coupling function.

  Raises:
    ValueError: if neither field selects a supported coupling type. Previously
      this case returned None implicitly, and callers then failed with an
      opaque unpacking error.
  """
  prior_type = kwargs["hparams"].prior_type
  posterior_type = kwargs["hparams"].posterior_type
  if prior_type == "affine" or posterior_type == "affine":
    return affine_coupling(*args, **kwargs)
  elif prior_type == "additive" or posterior_type == "additive":
    return additive_coupling(*args, **kwargs)
  raise ValueError(
      "Unsupported coupling: prior_type=%r, posterior_type=%r; expected "
      "'affine' or 'additive'." % (prior_type, posterior_type))
def additive_coupling(
    name, x, x_mask, inverse, split_dim, identity_first, init,
    decoder_self_attention_bias=None, **kwargs):
  """Additive (volume-preserving) coupling transform layer.

  `x` is split by `gops.split_coupling` into a part kept unchanged and a part
  to transform. A transformer decoder block conditioned on the kept part
  predicts a shift. The forward pass adds the shift and the inverse pass
  subtracts it.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    inverse: forward or inverse pass.
    split_dim: which dimension to split, forwarded to gops.split_coupling.
    identity_first: True keeps the first half constant; False the second.
    init: init.
    decoder_self_attention_bias: bias.
    **kwargs: additional arguments. Contains hparams and is forwarded to
      gops.transformer_decoder_block.

  Returns:
    result: transformed Tensor reshaped to [B, L, C].
    logabsdet: scalar 0.0, since a pure shift has unit Jacobian determinant.
  """
  hparams = kwargs["hparams"]
  batch_size, length, n_channels = common_layers.shape_list(x)
  assert hparams.scale_width > 0.0 and hparams.scale_width < 1.0
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    (kept, moved, _, n_transform, split_bias,
     split_mask) = gops.split_coupling(
         x, x_mask, split_dim, identity_first, decoder_self_attention_bias)
    shift = gops.transformer_decoder_block(
        "theta_tr",
        n_layers=hparams.n_layers_transform_params,
        x=kept,
        x_mask=split_mask,
        output_size=n_transform,
        init=init,
        decoder_self_attention_bias=split_bias,
        **kwargs)
    if inverse:
      moved = moved - shift
    else:
      moved = moved + shift
    zero_logabsdet = tf.constant(0.0, dtype=tf.float32)
    tf.summary.histogram("_loc", tf.boolean_mask(shift, split_mask))
    merged = gops.join_coupling(kept, moved, split_dim, identity_first)
    merged = tf.reshape(merged, [batch_size, length, n_channels])
  return merged, zero_logabsdet
def affine_coupling(
    name, x, x_mask, inverse, split_dim, identity_first, init,
    decoder_self_attention_bias=None, **kwargs):
  """Affine coupling transform layer.

  `x` is split by `gops.split_coupling` into a part kept unchanged and a part
  to transform. A transformer decoder block conditioned on the kept part
  predicts a shift and a raw scale. With `scale = sigmoid(raw + 2.0)`, which
  always lies in (0, 1):
  - the forward pass computes `(x + shift) * scale`;
  - the inverse pass computes `z / scale - shift`.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    inverse: forward or inverse pass.
    split_dim: which dimension to split
      (time, channel_continuous, channel_alternate).
    identity_first: True keeps the first half constant; False the second.
    init: init.
    decoder_self_attention_bias: bias.
    **kwargs: additional arguments. Contains hparams, encoder_output and
      encoder_decoder_attention_bias.

  Returns:
    z: data transformed by the affine coupling layer. shape=[B, L, C]
    logabsdets: Log absolute determinant Jacobian. shape=[B]
  """
  hparams = kwargs["hparams"]
  batch_size, length, n_channels = common_layers.shape_list(x)
  assert hparams.scale_width > 0.0 and hparams.scale_width < 1.0
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    (kept, moved, _, n_transform, split_bias,
     split_mask) = gops.split_coupling(
         x, x_mask, split_dim, identity_first, decoder_self_attention_bias)
    net_out = gops.transformer_decoder_block(
        "theta_tr",
        n_layers=hparams.n_layers_transform_params,
        x=kept,
        x_mask=split_mask,
        output_size=n_transform*2,
        init=init,
        decoder_self_attention_bias=split_bias,
        **kwargs)
    shift, raw_scale = tf.split(net_out, 2, axis=-1)
    # When raw_scale is near zero, the +2.0 offset gives sigmoid(2) ~= 0.88.
    scale = tf.sigmoid(raw_scale + 2.0)
    if inverse:
      moved = moved / scale - shift
    else:
      moved = (moved + shift) * scale
    logabsdet = gops.reduce_sum_over_lc(tf.log(scale), split_mask)  # [B]
    if inverse:
      logabsdet *= -1
    tf.summary.histogram("_loc", tf.boolean_mask(shift, split_mask))
    tf.summary.histogram("_scale", tf.boolean_mask(scale, split_mask))
    merged = gops.join_coupling(kept, moved, split_dim, identity_first)
    merged = tf.reshape(merged, [batch_size, length, n_channels])
  return merged, logabsdet
def flow_step_glow(name, x, x_mask, split_dims, inverse, init, dtype, **kwargs):
  """One step of flow: for each split dim, actnorm -> [1x1 conv] -> coupling.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    split_dims: iterable of single-character split codes (e.g. a string).
      Each one adds a coupling layer. "c" and "a" (channel splits) are also
      preceded by a multi-head 1x1 convolution.
    inverse: forward or inverse pass; the inverse applies ops in reverse.
    init: init.
    dtype: dtype.
    **kwargs: forwarded to `coupling` (must contain hparams).

  Returns:
    x: transformed Tensor.
    logabsdets: accumulated log absolute determinant Jacobian.
  """
  conv_fn = multihead_invertible_1x1_conv_np
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    reversible_ops = []
    for split_dim in split_dims:
      # NOTE(review): identity_first is always True. Also, every actnorm in
      # this step uses the same name "actnorm" under AUTO_REUSE, so they
      # share variables when split_dims has several entries. The names are
      # kept unchanged for checkpoint compatibility; confirm this is intended.
      identity_first = True
      reversible_ops.append(
          functools.partial(actnorm, name="actnorm", init=init))
      if split_dim in ("c", "a"):
        # Mix channels with the opposite head layout to the coupling split:
        # a contiguous ("c") coupling gets an interleaved ("a") 1x1 conv, and
        # vice versa.
        multihead_split = "a" if split_dim == "c" else "c"
        reversible_ops.append(functools.partial(
            conv_fn, name="conv_{}".format(multihead_split),
            multihead_split=multihead_split, dtype=dtype))
      reversible_ops.append(functools.partial(
          coupling, name="coupling_{}".format(split_dim),
          split_dim=split_dim, identity_first=identity_first, init=init,
          **kwargs))
    if inverse:
      reversible_ops = reversible_ops[::-1]

    logabsdets = tf.constant(0.0, dtype=dtype)
    for reversible_op in reversible_ops:
      x, logabsdet = reversible_op(x=x, x_mask=x_mask, inverse=inverse)
      logabsdets += logabsdet
  return x, logabsdets
def flow_level(
    name, x, x_mask, depth, split_dims, prior, inverse, init, dtype, **kwargs):
  """One level of flow: `depth` glow steps, optionally followed by a prior.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    depth: number of `flow_step_glow` steps.
    split_dims: split codes passed to every step.
    prior: if true, append a channel ("c") coupling layer named
      "{depth}_prior" after the steps.
    inverse: forward or inverse pass; the inverse applies ops in reverse.
    init: init.
    dtype: dtype.
    **kwargs: forwarded to each step and to the prior coupling.

  Returns:
    x: transformed Tensor.
    logabsdets: accumulated log absolute determinant Jacobian.
  """
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    layers = [
        functools.partial(
            flow_step_glow, name="{}_step".format(step), split_dims=split_dims,
            init=init, dtype=dtype, **kwargs)
        for step in np.arange(depth)
    ]
    if prior:
      layers.append(functools.partial(
          coupling, name="{}_prior".format(depth), split_dim="c",
          identity_first=True, init=init, **kwargs))
    ordered_layers = layers[::-1] if inverse else layers

    total_logabsdet = tf.constant(0.0, dtype=dtype)
    for layer in ordered_layers:
      x, logabsdet = layer(x=x, x_mask=x_mask, inverse=inverse)
      total_logabsdet += logabsdet
  return x, total_logabsdet
def split(name, x, x_mask, inverse, temp=1.0, dtype=tf.float32, z=None):
  """Splits x into two channel halves, or concatenates a Gaussian half back.

  The second half (x2) is modelled with a standard normal distribution.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    inverse: forward or inverse pass.
    temp: inverse pass only; standard deviation used when sampling x2.
    dtype: dtype of sampled x2.
    z: inverse pass only; if given, used as x2 instead of sampling (e.g. to
      check invertibility).

  Returns:
    Forward: (x1, x2, log_p), where x1 and x2 are the first and second halves
      of the channels and log_p = log p(x2; N(0,1)), shape=[B].
    Inverse: (concat([x, x2], axis=2), None, log_p).
  """
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    if inverse:
      if z is not None:
        x2 = z
      else:
        x2 = tf.random.normal(
            common_layers.shape_list(x), stddev=temp, dtype=dtype)
      log_p = gops.standard_normal_density(x2, x_mask)
      return tf.concat([x, x2], 2), None, log_p
    x1, x2 = tf.split(x, 2, axis=-1)
    log_p = gops.standard_normal_density(x2, x_mask)
    return x1, x2, log_p
def squeeze(name, x, factor, inverse):
  """Temporal squeezing: trades sequence length for channels (or back).

  Forward: [B, L, C] -> [B, L // factor, C * factor].
  Inverse: [B, L, C] -> [B, L * factor, C // factor].
  A factor of 1 returns x unchanged.

  Args:
    name: variable scope.
    x: 3-D Tensor.
    factor: squeeze factor.
    inverse: forward (squeeze) or inverse (unsqueeze) pass.

  Returns:
    The reshaped Tensor.
  """
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    if factor == 1:
      return x
    batch, steps, channels = common_layers.shape_list(x)
    if inverse:
      grouped = tf.reshape(x, (batch, steps, channels//factor, factor))
      grouped = tf.transpose(grouped, [0, 1, 3, 2])
      return tf.reshape(grouped, (batch, steps*factor, channels//factor))
    grouped = tf.reshape(x, [batch, steps//factor, factor, channels])
    # Swap the last two axes. In the flattened channel axis, the `factor`
    # neighbouring time steps of each original channel then sit next to
    # each other.
    grouped = tf.transpose(grouped, [0, 1, 3, 2])
    return tf.reshape(grouped, [batch, steps//factor, channels*factor])
def glow(
    name, x, max_x_mask, max_self_attn_bias, inverse, init, dtype=tf.float32,
    split_zs=None, temp=1.0, **kwargs):
  """Multi-scale glow model. Flow + (n_levels-1)*(Split + Squeeze + Flow).

  Note the original glow's ordering is Squeeze + Flow + Split.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C]. The length dimension is padded to the
      closest multiple of factor**n_levels.
    max_x_mask : 2-D Tensor, shape=[B, L]. Binary mask indicating padding.
    max_self_attn_bias : 4-D Tensor, shape=[B, 1, 1, L].
    inverse: forward or inverse pass.
    init: init.
    dtype: dtype.
    split_zs: intermediate latents modelled as a standard normal. Only read
      in the inverse pass: if given, split_zs[level - 1] is re-attached at
      each level > 0 instead of a fresh sample.
    temp: Only used in inverse. Temperature for sampling.
    **kwargs: additional arguments. Contains hparams, disable_dropout,
      encoder_output and encoder_decoder_attention_bias.

  Returns:
    x: if forward, data transformed to the base distribution.
      if inverse, base transformed to the data (latent) distribution.
    logabsdets: log absolute determinant Jacobian. [B]
    log_ps: log probability in the base distribution. [B]
    split_zs: all intermediate latents (only used to check invertibility.)
      Always None for the inverse pass.
  """
  assert x.shape.rank == 3
  hparams = kwargs["hparams"]
  factor = hparams.factor
  # One "/"-separated entry per level, e.g. depths="12/12/8", split_plans
  # "tca/tca/ca"; the number of depth entries defines n_levels.
  if hparams.depths:
    depths = [int(depth_str) for depth_str in hparams.depths.split("/")]
  else:
    depths = []
  split_plans = hparams.split_plans.split("/")
  n_levels = len(depths)
  logabsdets = tf.constant(0.0, dtype=dtype)
  log_ps = tf.constant(0.0, dtype=dtype)
  with tf.variable_scope(name, use_resource=True, reuse=tf.AUTO_REUSE):
    if not inverse:  # z -> e (density estimation)
      x_mask, self_attn_bias = max_x_mask, max_self_attn_bias
      split_zs = []
      for level in range(n_levels):
        if level > 0:
          # `dtype` must be passed by keyword: the fifth positional parameter
          # of split() is `temp`, not `dtype`.
          x, z, log_p_z = split(
              "{}_split".format(level), x, x_mask, inverse, dtype=dtype)
          log_ps += log_p_z
          split_zs.append(z)
          x = squeeze("{}_squeeze".format(level), x, factor, inverse)
          # After `level` squeezes, only every factor**level-th position of
          # the full-resolution mask / bias remains relevant.
          x_mask = max_x_mask[:, ::factor**level]
          self_attn_bias = max_self_attn_bias[..., ::factor**level]
        # Every level except the last gets a prior.
        prior = level < n_levels - 1
        x, logabsdet = flow_level(
            "{}_level".format(level), x, x_mask, depths[level],
            split_plans[level], prior, inverse, init, dtype,
            decoder_self_attention_bias=self_attn_bias, **kwargs)
        logabsdets += logabsdet  # (B)
      log_p_base = gops.standard_normal_density(x, x_mask)
      log_ps += log_p_base
      return x, logabsdets, log_ps, split_zs
    else:  # e -> z (sampling)
      x_mask = max_x_mask[:, ::factor**(n_levels-1)]
      log_p_base = gops.standard_normal_density(x, x_mask)
      log_ps += log_p_base
      if split_zs is None:
        split_zs = [None] * (n_levels-1)
      for level in reversed(range(n_levels)):
        x_mask = max_x_mask[:, ::factor**level]
        self_attn_bias = max_self_attn_bias[..., ::factor**level]
        prior = level < n_levels - 1
        x, logabsdet = flow_level(
            "{}_level".format(level), x, x_mask, depths[level],
            split_plans[level], prior, inverse, init, dtype,
            decoder_self_attention_bias=self_attn_bias, **kwargs)
        logabsdets += logabsdet
        if level > 0:
          # Undo the squeeze, then re-attach the channels split off at this
          # level during the forward pass.
          x = squeeze("{}_squeeze".format(level), x, factor, inverse)
          x_mask = max_x_mask[:, ::factor**(level-1)]
          x, _, log_p_z = split(
              "{}_split".format(level), x, x_mask, inverse, temp=temp,
              dtype=dtype, z=split_zs[level-1])
          log_ps += log_p_z
      return x, logabsdets, log_ps, None
================================================
FILE: tensor2tensor/layers/transformer_glow_layers_ops.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Additional operations for transformer_glow_layers.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import math
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models.transformer import transformer_decoder_layer
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
def dense(name, x, n_out, dtype=tf.float32, init_w=0.05):
  """Plain affine layer: returns x @ w + b.

  The input width is read from axis 2 of `x`. `w` is initialised from
  N(0, init_w) and `b` from zeros; both live in variable scope `name`.
  """
  in_dim = common_layers.shape_list(x)[2]
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    weight = tf.get_variable(
        "w", [in_dim, n_out], dtype,
        initializer=tf.random_normal_initializer(0.0, init_w),
        trainable=True)
    bias = tf.get_variable(
        "b", [n_out,], dtype, initializer=tf.zeros_initializer,
        trainable=True)
    return tf.matmul(x, weight) + bias
def dense_weightnorm(
    name, x, n_out, x_mask, init_scale, init, dtype=tf.float32):
  """Weight-normalised dense layer with data-dependent initialisation.

  The weight is `g * v / ||v||` (norm taken per output column). The scale `g`
  and bias `b` are created through `get_variable_ddi`: when `init` is true
  they are assigned values that give the unmasked pre-activations zero mean
  and (approximately) standard deviation `init_scale` per output channel.

  Args:
    name: variable scope.
    x: input; the input width is read from axis 2.
    n_out: number of output units.
    x_mask: binary padding mask used for the initialisation statistics.
    init_scale: target standard deviation of the initial outputs.
    init: python bool or boolean Tensor selecting data-dependent init.
    dtype: variable dtype.

  Returns:
    x @ (g * v_normalised) + b.
  """
  in_dim = common_layers.shape_list(x)[2]
  epsilon = tf.keras.backend.epsilon()
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    direction = tf.get_variable(
        "v", [in_dim, n_out], dtype,
        initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
    direction = direction / tf.norm(direction, axis=0, keepdims=True)
    projected = tf.matmul(x, direction)  # [B, L, n_out]
    proj_mean, proj_var = moments_over_bl(projected, x_mask)
    # Scale that maps the projected activations to std `init_scale`.
    g_init = init_scale / (tf.sqrt(proj_var) + epsilon)
    scale = get_variable_ddi(
        "g", [n_out], g_init, init,
        initializer=tf.zeros_initializer, dtype=dtype, trainable=True)
    shift = get_variable_ddi(
        "b", [n_out], -proj_mean * g_init, init,
        initializer=tf.zeros_initializer, dtype=dtype, trainable=True)
    result = tf.matmul(x, scale * direction) + shift
    tf.summary.histogram("_g", scale)
    return result
def transformer_decoder_block(name,
                              n_layers,
                              x,
                              x_mask,
                              output_size,
                              init,
                              **kwargs):
  """Stack of transformer decoder layers followed by a projection.

  Args:
    name: variable scope.
    n_layers: number of transformer layers.
    x: input to transformation.
    x_mask: mask.
    output_size: output dimensionality.
    init: data-dependent init for weightnorm parameters.
    **kwargs: Contains hparams and disable_dropout (both consumed here), plus
      encoder_output, encoder_decoder_attention_bias and
      decoder_self_attention_bias, which are forwarded to every decoder layer.

  Returns:
    outputs: Tensor of shape [batch_size, length, output_size].
  """
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    hparams = kwargs.pop("hparams")
    disable_dropout = kwargs.pop("disable_dropout")
    n_channels = common_layers.shape_list(x)[-1]
    if disable_dropout or n_channels != hparams.hidden_size:
      # Modify a private copy so the caller's hparams stay untouched.
      hparams = copy.deepcopy(hparams)
      if disable_dropout:
        hparams.attention_dropout = 0.0
        hparams.layer_prepostprocess_dropout = 0.0
        hparams.relu_dropout = 0.0
      if n_channels != hparams.hidden_size:
        # The decoder layers must match the width of the flow input.
        hparams.hidden_size = n_channels
    outputs = common_attention.add_timing_signal_1d(x)
    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
      for layer_idx in range(n_layers):
        outputs = transformer_decoder_layer(
            decoder_input=outputs,
            layer_idx=layer_idx,
            hparams=hparams,
            **kwargs)
      outputs = common_layers.layer_preprocess(outputs, hparams)
    # init_scale=0.0 makes the data-dependent init start this projection at a
    # zero output.
    return dense_weightnorm(
        "h2o", outputs, output_size, x_mask, init_scale=0.0, init=init)
def reduce_sum_over_lc(x, x_mask):
  """Masked sum of x over the length and channel axes.

  Args:
    x: input. (B,L,C)
    x_mask: binary padding mask. (B,L)
  Returns:
    sum of x over unmasked positions and all channels. (B)
  Raises:
    ValueError: if x is not rank 3 or x_mask is not rank 2.
  """
  if not (x.shape.rank == 3 and x_mask.shape.rank == 2):
    tf.logging.info("x: {}, x_mask: {}".format(x.shape.rank, x_mask.shape.rank))
    raise ValueError("Dimension not supported.")
  masked = x * x_mask[..., tf.newaxis]
  return tf.reduce_sum(masked, axis=[1, 2])  # sum over L, C
def reduce_sum_over_l(x, x_mask):
  """Masked sum of x over the length axis only.

  Args:
    x: input. (B,L,C)
    x_mask: binary padding mask. (B,L)
  Returns:
    sum of x over unmasked positions. (B,C)
  Raises:
    ValueError: if x is not rank 3 or x_mask is not rank 2.
  """
  if not (x.shape.rank == 3 and x_mask.shape.rank == 2):
    tf.logging.info("x: {}, x_mask: {}".format(x.shape.rank, x_mask.shape.rank))
    raise ValueError("Dimension not supported.")
  masked = x * x_mask[..., tf.newaxis]
  return tf.reduce_sum(masked, axis=-2)  # sum over L
def reduce_mean_over_l(x, x_mask):
  """Per-example mean of x over the unmasked length positions. (B,C)"""
  total = reduce_sum_over_l(x, x_mask)
  n_valid = tf.reduce_sum(x_mask, 1, keepdims=True)  # (B,1)
  return total / n_valid
def reduce_mean_over_bl(x, x_mask):
  """Mean of x over all unmasked (batch, length) positions.

  Args:
    x: input. (B,L,C)
    x_mask: binary padding mask. (B,L)
  Returns:
    mean of x per channel. (C)
  Raises:
    ValueError: if x is not rank 3 or x_mask is not rank 2.
  """
  if not (x.shape.rank == 3 and x_mask.shape.rank == 2):
    tf.logging.info("x: {}, x_mask: {}".format(x.shape.rank, x_mask.shape.rank))
    raise ValueError("Dimension not supported.")
  mask3 = x_mask[..., tf.newaxis]
  total = tf.reduce_sum(x * mask3, axis=[0, 1])  # sum over B, L
  return total / tf.reduce_sum(mask3)
def reduce_mean_over_l_sum_over_c(x, x_mask):
  """Per-example sum over C of the mean over unmasked L positions. (B)"""
  return reduce_sum_over_lc(x, x_mask) / tf.reduce_sum(x_mask, 1)
def reduce_mean_over_bl_sum_over_c(x, x_mask):
  """Scalar: per-channel mean over unmasked (B, L) positions, summed over C."""
  per_channel_mean = reduce_mean_over_bl(x, x_mask)  # (C)
  return tf.reduce_sum(per_channel_mean)
def moments_over_bl(x, x_mask):
  """Per-channel masked mean and (biased) variance over B and L."""
  mu = reduce_mean_over_bl(x, x_mask)
  centered = x - mu
  sigma2 = reduce_mean_over_bl(centered**2, x_mask)
  return mu, sigma2
def standard_normal_density(x, x_mask, reduce_sum=False):
  """Masked log density of x under an element-wise standard normal.

  Computes log N(x; 0, 1) = -0.5 * (x**2 + log(2*pi)) element-wise and then
  reduces it over the unmasked positions.

  Args:
    x: input. (B,L,C)
    x_mask: binary padding mask. (B,L)
    reduce_sum: selects the reduction (despite the name, True averages):
      False (default) -- sum over L and C, one value per example. (B)
      True -- mean over unmasked (B, L) positions, summed over C. (scalar)

  Returns:
    The reduced log density, shape (B) or scalar as described above.
  """
  log_probs = -0.5 * (x**2 + math.log(math.pi * 2.0))
  if reduce_sum:
    return reduce_mean_over_bl_sum_over_c(log_probs, x_mask)
  return reduce_sum_over_lc(log_probs, x_mask)
def standard_normal(x, name="normal"):
  """Returns an element-wise N(0, 1) distribution shaped like x."""
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    zeros, ones = tf.zeros_like(x), tf.ones_like(x)
    return tfp.distributions.Normal(
        loc=zeros, scale=ones, allow_nan_stats=False)
def diagonal_normal(outputs, name="normal"):
  """Builds a diagonal Normal distribution from [loc, log_scale] outputs.

  The last axis of `outputs` is split in half: the first half is the mean and
  the second half the log standard deviation. Keras' epsilon is added to the
  standard deviation to keep it strictly positive.

  Args:
    outputs: Tensor whose last axis holds concat([loc, log_scale]).
    name: variable scope.

  Returns:
    A `tfp.distributions.Normal` (the distribution itself, not a sample).
  """
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    loc, log_scale = tf.split(outputs, 2, axis=-1)
    scale = tf.exp(log_scale)
    dist = tfp.distributions.Normal(
        loc=loc,
        scale=scale + tf.keras.backend.epsilon(),
        allow_nan_stats=False)
    return dist
def split_coupling(
    x, x_mask, split_dim, identity_first, decoder_self_attention_bias):
  """Splits x into an identity part and a to-be-transformed part.

  Args:
    x: input, presumably [B, L, C] (the "t" mode slices axis 1).
    x_mask: padding mask whose last axis is the length axis.
    split_dim: "c" -- first/second half of the channels;
      "a" -- alternating (even/odd) channels;
      "t" -- alternating (even/odd) time steps.
    identity_first: if True the first (or even) part is the identity part.
    decoder_self_attention_bias: self-attention bias; for "t" its last axis
      (and the mask) are subsampled to the even positions.

  Returns:
    (x_id, x_tr, n_identity, n_transform, bias, mask). For "c" and "a" the
    sizes are n_channels // 2, which assumes an even channel count.

  Raises:
    ValueError: if split_dim is not one of "c", "a" or "t".
  """
  n_channels = common_layers.shape_list(x)[-1]
  if split_dim == "c":
    n_transform = n_identity = n_channels // 2
    x_id = x[..., :n_identity] if identity_first else x[..., n_transform:]
    x_tr = x[..., n_identity:] if identity_first else x[..., :n_transform]
    bias, mask = decoder_self_attention_bias, x_mask
  elif split_dim == "a":
    n_transform = n_identity = n_channels // 2
    x_id = x[..., 0::2] if identity_first else x[..., 1::2]
    x_tr = x[..., 1::2] if identity_first else x[..., 0::2]
    bias, mask = decoder_self_attention_bias, x_mask
  elif split_dim == "t":
    n_transform = n_identity = n_channels
    x_id = x[:, 0::2, :] if identity_first else x[:, 1::2, :]
    x_tr = x[:, 1::2, :] if identity_first else x[:, 0::2, :]
    bias, mask = decoder_self_attention_bias[..., 0::2], x_mask[..., 0::2]
  else:
    # Previously this fell through to an UnboundLocalError on `x_id`.
    raise ValueError(
        "Unsupported split_dim {!r}; expected 'c', 'a' or 't'.".format(
            split_dim))
  return x_id, x_tr, n_identity, n_transform, bias, mask
def join_coupling(z_id, z_tr, split_dim, identity_first):
  """Reverses split_coupling's split of the identity and transformed parts.

  For "c" the parts are concatenated on axis 2. For "a" and "t" they are
  stacked on a new axis (3 and 2 respectively), so the result is rank 4 and
  still needs a reshape by the caller to interleave the parts.

  Args:
    z_id: rank-3 identity part.
    z_tr: rank-3 transformed part.
    split_dim: "c", "a" or "t", as used for split_coupling.
    identity_first: whether the identity part comes first / at even indices.

  Returns:
    The joined Tensor (rank 3 for "c", rank 4 for "a" and "t").

  Raises:
    ValueError: if split_dim is not one of "c", "a" or "t".
  """
  assert z_id.shape.rank == 3 and z_tr.shape.rank == 3
  result = [z_id, z_tr] if identity_first else [z_tr, z_id]
  if split_dim == "c":
    result = tf.concat(result, axis=2)  # concat in the channel dimension
  elif split_dim == "a":
    result = tf.stack(result, axis=3)  # stack in the channel dimension
  elif split_dim == "t":
    result = tf.stack(result, axis=2)  # stack in the time dimension
  else:
    # Previously a plain Python list was returned silently.
    raise ValueError(
        "Unsupported split_dim {!r}; expected 'c', 'a' or 't'.".format(
            split_dim))
  return result
def assign(w, initial_value):
  """Assigns `initial_value` to variable `w` and returns the assign result."""
  assigned = w.assign(initial_value)
  with tf.control_dependencies([assigned]):
    result = assigned
  return result
def get_variable_ddi(
    name, shape, value, init, initializer=None, dtype=tf.float32,
    regularizer=None, trainable=True):
  """Creates/reuses a variable that is optionally overwritten with `value`.

  If `init` is a python bool, the assignment is decided at graph-construction
  time; otherwise `init` is used as the predicate of a tf.cond, so the
  assignment happens only when it evaluates to True.
  """
  var_kwargs = {"trainable": trainable}
  if initializer:
    var_kwargs["initializer"] = initializer
  if regularizer:
    var_kwargs["regularizer"] = regularizer
  w = tf.get_variable(name, shape, dtype, **var_kwargs)
  if not isinstance(init, bool):
    return tf.cond(init, lambda: assign(w, value), lambda: w)
  return assign(w, value) if init else w
================================================
FILE: tensor2tensor/layers/transformer_glow_layers_ops_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.layers.transformer_glow_layers_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensor2tensor.layers import transformer_glow_layers_ops as gops
from tensor2tensor.models import transformer
import tensorflow.compat.v1 as tf
# Sizes of the random test inputs.
BATCH_SIZE = 10
INPUT_LENGTH = 3  # length of the random encoder_output
TARGET_LENGTH = 16  # length of x / decoder_self_attention_bias
N_CHANNELS = 24  # channels of x and hparams.latent_size
HIDDEN_SIZE = 64  # hparams.hidden_size and encoder_output width
N_1X1_HEADS = 4  # hparams.flow_num_1x1_heads
class TransformerFlowOpsTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the helper ops in transformer_glow_layers_ops."""

  def get_data(self):
    """Returns a random [B, L, C] input and a [B, L] random-length mask."""
    x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                         mean=0.0, stddev=1.0)
    lengths = np.random.randint(low=1, high=TARGET_LENGTH + 1,
                                size=BATCH_SIZE)
    mask = tf.sequence_mask(lengths, maxlen=TARGET_LENGTH, dtype=tf.float32)
    return x, mask

  def get_hparams(self):
    """Returns transformer_small hparams extended with flow settings."""
    hparams = transformer.transformer_small()
    flow_hparams = (
        ("prior_type", "affine"),
        ("depths", "12"),  # infer n_levels from depths
        ("split_plans", "tca"),
        ("factor", 2),  # squeezing factor
        ("n_layers_transform_params", 1),
        ("n_layers_multiscale_prior", 3),
        ("flow_num_heads", 4),
        ("flow_num_1x1_heads", N_1X1_HEADS),
        ("flow_hidden_size", 64),
        ("flow_filter_size", 128),
        ("cond_prior_on_src", True),
        ("bottom_prior_std", False),
        ("latent_size", N_CHANNELS),
        ("scale_width", 0.999),
        ("coupling_transform_ratio", 0.5),
        ("actnorm_type", "actnorm"),
        ("actnorm_weightnorm", True),
        ("perm_type", "1x1"),
        ("init_permutation", True),
    )
    for hparam_name, hparam_value in flow_hparams:
      hparams.add_hparam(hparam_name, hparam_value)
    hparams.causal_decoder_self_attention = False
    hparams.hidden_size = HIDDEN_SIZE
    return hparams

  def get_kwargs(self, hparams=None):
    """Returns decoder kwargs with random encoder output and biases."""
    if hparams is None:
      hparams = self.get_hparams()
    encoder_output = tf.random.uniform(
        (BATCH_SIZE, INPUT_LENGTH, HIDDEN_SIZE))
    enc_dec_bias = tf.random.uniform((BATCH_SIZE, 1, 1, INPUT_LENGTH))
    self_attn_bias = tf.random.uniform((BATCH_SIZE, 1, 1, TARGET_LENGTH))
    return {
        "hparams": hparams,
        "encoder_output": encoder_output,
        "encoder_decoder_attention_bias": enc_dec_bias,
        "decoder_self_attention_bias": self_attn_bias,
    }

  def test_dense_weightnorm(self):
    # With init=True and init_scale=1.0 the unmasked outputs should have
    # zero mean and unit variance per output channel.
    _, x_mask = self.get_data()
    inputs = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, HIDDEN_SIZE),
                              mean=0.0, stddev=1.0)
    y = gops.dense_weightnorm("wn", inputs, N_CHANNELS, x_mask,
                              init_scale=1.0, init=True)
    mean, var = tf.nn.moments(tf.boolean_mask(y, x_mask), axes=[0])
    self.evaluate(tf.global_variables_initializer())
    y_val, mean_val, var_val = self.evaluate([y, mean, var])
    self.assertEqual(y_val.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS))
    self.assertTrue(np.allclose(mean_val, 0.0, atol=1e-5))
    self.assertTrue(np.allclose(var_val, 1.0, atol=1e-5))
# Run the test cases when this module is executed directly.
if __name__ == "__main__":
  tf.test.main()
================================================
FILE: tensor2tensor/layers/transformer_glow_layers_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.layers.transformer_glow_layers.
1. Actnorm test (zero mean and unit variance).
2. Invertibility tests for:
* actnorm
* actnorm with weight normalization
* 1x1 invertible convolution
* multi-head 1x1 invertible convolution
* affine coupling
* split
* 1 step of flow
* k steps of flow
* entire pipeline (tested up to 3 levels, 32 steps: tca/tca/ca, 12/12/8)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from absl.testing import parameterized
import numpy as np
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import transformer_glow_layers as glow
from tensor2tensor.layers import transformer_glow_layers_ops as gops
from tensor2tensor.models import transformer
import tensorflow.compat.v1 as tf
# Sizes and dtype of the random test inputs.
BATCH_SIZE = 20
INPUT_LENGTH = 3  # length of the random encoder_output
TARGET_LENGTH = 16  # length of x and of the padding mask
N_CHANNELS = 256  # channels of x and hparams.latent_size
HIDDEN_SIZE = 64  # hparams.hidden_size and encoder_output width
N_1X1_HEADS = 4  # hparams.n_1x1_heads
DTYPE = tf.float32
def float32_bottleneck(x):
  """Rounds x through float32 and returns the result as float64."""
  as_float32 = tf.cast(x, tf.float32)
  return tf.cast(as_float32, tf.float64)
def get_diff(l1, l2):
  """Debug helper: prints differences between l1 and l2 taken in reverse.

  Element i of l1 is paired with element -(i+1) of l2. All raw differences
  are printed first, then the maximum absolute value of each difference.
  """
  l2_reversed = l2[::-1]
  for delta in (a - b for a, b in zip(l1, l2_reversed)):
    print(delta)
  for a, b in zip(l1, l2_reversed):
    print(np.max(np.abs(a - b)))
class TransformerGlowLayersTest(parameterized.TestCase, tf.test.TestCase):
  """Statistics and invertibility tests for transformer_glow_layers."""

  def get_hparams(self):
    hparams = transformer.transformer_small()
    hparams.add_hparam("prior_type", "affine")
    hparams.add_hparam("factor", 2)  # squeezing factor
    hparams.add_hparam("n_layers_transform_params", 1)
    hparams.add_hparam("n_1x1_heads", N_1X1_HEADS)
    hparams.add_hparam("flow_num_1x1_heads", 4)
    hparams.add_hparam("flow_num_heads", 4)
    hparams.add_hparam("flow_hidden_size", 64)
    hparams.add_hparam("flow_filter_size", 128)
    hparams.add_hparam("flow_layer_prepostprocess_dropout", 0.0)
    hparams.add_hparam("flow_attention_dropout", 0.0)
    hparams.add_hparam("flow_relu_dropout", 0.0)
    hparams.add_hparam("latent_size", N_CHANNELS)
    hparams.add_hparam("use_weightnorm", True)
    hparams.add_hparam("kl_startup_steps", 2000)
    hparams.add_hparam("affine_scale", "glow")
    hparams.add_hparam("scale_width", 0.999)
    hparams.add_hparam("step_fn", "glow")  # glow / chunting
    hparams.add_hparam("conv_fn", "np")  # np / tf
    hparams.add_hparam("posterior_type", "diagonal_normal")
    hparams.causal_decoder_self_attention = False
    hparams.hidden_size = HIDDEN_SIZE
    hparams.weight_dtype = "float32"
    hparams.add_hparam("pos_attn", False)
    return hparams

  def get_data(self):
    x = tf.random_normal(
        (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS), dtype=DTYPE)
    x_lengths = np.random.randint(
        low=1, high=TARGET_LENGTH+1, size=BATCH_SIZE)
    # Round lengths up to a multiple of 4 so squeezed masks stay aligned.
    x_lengths = np.ceil(x_lengths / 4.0) * 4.0
    x_lengths = x_lengths.astype(int)
    x_mask = tf.sequence_mask(x_lengths, maxlen=TARGET_LENGTH, dtype=DTYPE)
    return x, x_mask, x_lengths

  def get_kwargs(self, x_mask, hparams=None):
    if hparams is None:
      hparams = self.get_hparams()
    encoder_output = tf.random.uniform(
        (BATCH_SIZE, INPUT_LENGTH, HIDDEN_SIZE), dtype=DTYPE)
    encoder_decoder_attention_bias = tf.zeros(
        (BATCH_SIZE, 1, 1, INPUT_LENGTH), dtype=DTYPE)
    decoder_self_attention_bias = 1.0 - x_mask[:, tf.newaxis, tf.newaxis, :]
    decoder_self_attention_bias *= -1e9
    kwargs = {"hparams": hparams,
              "encoder_output": encoder_output,
              "encoder_decoder_attention_bias": encoder_decoder_attention_bias,
              "decoder_self_attention_bias": decoder_self_attention_bias}
    return kwargs

  def test_actnorm(self):
    _, x_mask, _ = self.get_data()
    x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                         mean=50.0, stddev=10.0, dtype=DTYPE)
    x_act, logabsdet = glow.actnorm(
        "actnorm", x, x_mask, inverse=False, init=True)
    x_act_nopad = tf.boolean_mask(x_act, x_mask)
    x_mean, x_var = tf.nn.moments(x_act_nopad, axes=[0])
    self.evaluate(tf.global_variables_initializer())
    x, x_act, logabsdet, x_mean, x_var = (
        self.evaluate([x, x_act, logabsdet, x_mean, x_var]))
    self.assertEqual(x_act.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS))
    self.assertEqual(logabsdet.shape, (BATCH_SIZE,))
    self.assertTrue(np.allclose(x_mean, 0.0, atol=1e-5))
    self.assertTrue(np.allclose(x_var, 1.0, atol=1e-5))

  def test_actnorm_invertibility(self):
    name = "actnorm"
    x, x_mask, _ = self.get_data()
    x_inv, logabsdet = glow.actnorm(
        name, x, x_mask, inverse=False, init=False)
    x_inv_inv, logabsdet_inv = glow.actnorm(
        name, x_inv, x_mask, inverse=True, init=False)
    self.evaluate(tf.global_variables_initializer())
    x, x_inv, x_inv_inv, x_mask, logabsdet, logabsdet_inv = (
        self.evaluate(
            [x, x_inv, x_inv_inv, x_mask, logabsdet, logabsdet_inv]))
    diff = x - x_inv_inv
    logabsdet_sum = logabsdet + logabsdet_inv
    self.assertEqual(x.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS))
    self.assertEqual(x_inv.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS))
    self.assertEqual(x_inv_inv.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS))
    self.assertTrue(np.allclose(diff, 0.0, atol=1e-5))
    self.assertTrue(np.allclose(logabsdet_sum, 0.0, atol=1e-5))

  @parameterized.parameters(
      (glow.multihead_invertible_1x1_conv_np, "a"),
      (glow.multihead_invertible_1x1_conv_np, "c"),
  )
  def test_multi_1x1_invertibility(
      self, func, multihead_split):
    name = "multi_1x1"
    x, x_mask, _ = self.get_data()
    x_inv, logabsdet = func(
        name, x, x_mask, multihead_split, inverse=False, dtype=DTYPE)
    x_inv_inv, logabsdet_inv = func(
        name, x_inv, x_mask, multihead_split, inverse=True, dtype=DTYPE)
    self.evaluate(tf.global_variables_initializer())
    x, x_mask, x_inv, x_inv_inv, logabsdet, logabsdet_inv = (
        self.evaluate(
            [x, x_mask, x_inv, x_inv_inv, logabsdet, logabsdet_inv]))
    diff = x - x_inv_inv
    logabsdet_sum = logabsdet + logabsdet_inv
    logabsdet_ = logabsdet / np.sum(x_mask, -1)
    self.assertTrue(np.allclose(diff, 0.0, atol=1e-5))
    self.assertTrue(np.allclose(logabsdet_, 0.0, atol=1e-5))
    self.assertTrue(np.allclose(logabsdet_sum, 0.0, atol=1e-5))

  @parameterized.parameters(
      (glow.additive_coupling, "c"),
      (glow.additive_coupling, "t"),
      (glow.additive_coupling, "a"),
      (glow.affine_coupling, "c"),
      (glow.affine_coupling, "t"),
      (glow.affine_coupling, "a"),
  )
  def test_coupling_invertibility(self, func, split_dim):
    name = "affine"
    x, x_mask, _ = self.get_data()
    kwargs = self.get_kwargs(x_mask)
    x_inv, logabsdet = func(
        name, x, x_mask, split_dim=split_dim,
        identity_first=True, inverse=False, init=False, disable_dropout=True,
        **kwargs)
    x_inv_inv, logabsdet_inv = func(
        name, x_inv, x_mask, split_dim=split_dim,
        identity_first=True, inverse=True, init=False, disable_dropout=True,
        **kwargs)
    self.evaluate(tf.global_variables_initializer())
    x, x_mask, x_inv, x_inv_inv, logabsdet, logabsdet_inv = (
        self.evaluate(
            [x, x_mask, x_inv, x_inv_inv, logabsdet, logabsdet_inv]))
    diff = x - x_inv_inv
    logabsdet_sum = logabsdet + logabsdet_inv
    self.assertTrue(np.allclose(diff, 0.0, atol=1e-5))
    self.assertTrue(np.allclose(logabsdet_sum, 0.0, atol=1e-5))

  def test_split(self):
    x, x_mask, _ = self.get_data()
    x_inv, z, log_p = glow.split(
        "split", x, x_mask, inverse=False)
    x_inv_inv, _, log_p_inv = glow.split(
        "split", x_inv, x_mask, z=z, inverse=True)
    self.evaluate(tf.global_variables_initializer())
    x, x_inv, x_inv_inv, z, log_p, log_p_inv = self.evaluate(
        [x, x_inv, x_inv_inv, z, log_p, log_p_inv])
    diff = x - x_inv_inv
    log_p_diff = log_p - log_p_inv
    self.assertEqual(
        x_inv.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS//2))
    self.assertEqual(
        z.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS//2))
    self.assertTrue(np.allclose(diff, 0.0, atol=1e-5))
    self.assertTrue(np.allclose(log_p_diff, 0.0, atol=1e-5))

  def test_flow_invertibility(self):
    name = "flow_step"
    split_dims = "cat"
    x, x_mask, _ = self.get_data()
    kwargs = self.get_kwargs(x_mask)
    x_inv, logabsdet = glow.flow_step_glow(
        name, x, x_mask, split_dims, inverse=False, init=False, dtype=DTYPE,
        disable_dropout=True, **kwargs)
    x_inv_inv, logabsdet_inv = glow.flow_step_glow(
        name, x_inv, x_mask, split_dims, inverse=True, init=False,
        dtype=DTYPE, disable_dropout=True, **kwargs)
    self.evaluate(tf.global_variables_initializer())
    x, x_mask, x_inv, x_inv_inv, logabsdet, logabsdet_inv = (
        self.evaluate(
            [x, x_mask, x_inv, x_inv_inv, logabsdet, logabsdet_inv]))
    diff = x - x_inv_inv
    logabsdet_sum = logabsdet + logabsdet_inv
    self.assertTrue(np.allclose(diff, 0.0, atol=2e-5))
    self.assertTrue(np.allclose(logabsdet_sum, 0.0, atol=7e-5))

  @parameterized.parameters(
      ("1", "cat", "affine"),
      ("1/1", "cat/cat", "affine"),
      ("1/1/1", "cat/cat/ca", "affine"),
  )
  def test_aaa_glow_training(self, depths, split_plans, prior_type):
    # Keep the checkpoint under the test's managed temp dir so it is cleaned
    # up with it, instead of leaking one directory per parameter set.
    curr_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    model_path = os.path.join(curr_dir, "model")
    # Graph 1: build with data-dependent init and save a checkpoint.
    with tf.Graph().as_default():
      _, x_mask, _ = self.get_data()
      x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                           mean=10.0, stddev=3.0, dtype=DTYPE)
      bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
      hparams = self.get_hparams()
      hparams.prior_type = prior_type
      hparams.depths = depths
      hparams.split_plans = split_plans
      n_levels = len(hparams.depths.split("/"))
      kwargs = self.get_kwargs(x_mask, hparams)
      _ = kwargs.pop("decoder_self_attention_bias")
      x_inv, _, _, _ = glow.glow(
          "glow", x, x_mask, bias, inverse=False, init=True,
          disable_dropout=True, **kwargs)
      with tf.Session() as session:
        saver = tf.train.Saver()
        session.run(tf.global_variables_initializer())
        session.run(x_inv)
        saver.save(session, model_path)
    # Graph 2: rebuild the same model, restore it and check invertibility.
    with tf.Graph().as_default():
      _, x_mask, _ = self.get_data()
      x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                           mean=10.0, stddev=3.0, dtype=DTYPE)
      bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
      hparams = self.get_hparams()
      # Must match graph 1 so the restored variables line up.
      hparams.prior_type = prior_type
      hparams.depths = depths
      hparams.split_plans = split_plans
      kwargs = self.get_kwargs(x_mask, hparams)
      _ = kwargs.pop("decoder_self_attention_bias")
      log_q_z = gops.standard_normal_density(x, x_mask)
      log_q_z = tf.reduce_sum(log_q_z) / tf.reduce_sum(x_mask)
      x_inv, logabsdets, log_ps, zs = glow.glow(
          "glow", x, x_mask, bias, inverse=False, init=False,
          disable_dropout=True, **kwargs)
      x_inv_inv, logabsdets_inv, log_ps_inv, _ = glow.glow(
          "glow", x_inv, x_mask, bias, inverse=True, split_zs=zs, init=False,
          disable_dropout=True, **kwargs)
      logabsdets = tf.reduce_sum(
          logabsdets, axis=0) / tf.reduce_sum(x_mask)
      logabsdets_inv = tf.reduce_sum(
          logabsdets_inv, axis=0) / tf.reduce_sum(x_mask)
      log_ps = tf.reduce_sum(log_ps, axis=0) / tf.reduce_sum(x_mask)
      log_ps_inv = tf.reduce_sum(log_ps_inv, axis=0) / tf.reduce_sum(x_mask)
      with tf.Session() as session:
        saver = tf.train.Saver()
        saver.restore(session, model_path)
        (x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
         logabsdets_inv, log_ps_inv) = session.run([
             x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
             logabsdets_inv, log_ps_inv])
        diff = x - x_inv_inv
        log_ps_diff = log_ps - log_ps_inv
        logabsdets_sum = logabsdets + logabsdets_inv
        self.assertEqual(
            x_inv.shape,
            (BATCH_SIZE, TARGET_LENGTH//(2**(n_levels-1)), N_CHANNELS))
        # The failure messages report the maximum deviation.
        self.assertTrue(np.allclose(diff, 0.0, atol=1e-4),
                        msg=np.max(np.abs(diff)))
        self.assertTrue(np.allclose(log_ps_diff, 0.0, atol=1e-4),
                        msg=np.max(np.abs(log_ps_diff)))
        self.assertTrue(np.allclose(logabsdets_sum, 0.0, atol=1e-4),
                        msg=np.max(np.abs(logabsdets_sum)))
# Run the test cases when this module is executed directly.
if __name__ == "__main__":
  tf.test.main()
================================================
FILE: tensor2tensor/layers/transformer_layers.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Commonly re-used transformer layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import expert_utils
from tensor2tensor.utils import mlperf_log
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# TODO(lukaszkaiser): remove this function when not needed any more.
def layers():
  """Thin wrapper that forwards to common_layers.layers()."""
  layers_module = common_layers.layers()
  return layers_module
def transformer_prepare_encoder(inputs, target_space, hparams, features=None,
                                type_ids=None, num_types=None,
                                reuse_target_embedding=tf.AUTO_REUSE):
  """Builds the encoder input for one shard together with attention biases.

  Args:
    inputs: a Tensor.
    target_space: a Tensor, or None to skip the target-space embedding.
    hparams: run hyperparameters.
    features: optionally pass the entire features dictionary as well. When it
      contains "inputs_segmentation" the batch is treated as a "packed"
      dataset and the segmentation / position features are used.
    type_ids: optional, an int64 Tensor of shape [batch, length] that allows
      for adding type embeddings, similar to positional embeddings.
    num_types: optional, an int that decides the number of types in type_ids;
      required when type_ids is given.
    reuse_target_embedding: `reuse` value for the target-space embedding, so
      symbol modalities shared between inputs/targets can reuse the variable.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention

  Raises:
    ValueError: if type_ids is given without num_types.
  """
  static_shape = inputs.shape.as_list()
  encoder_input = inputs
  unidirectional = (hasattr(hparams, "unidirectional_encoder") and
                    hparams.unidirectional_encoder)

  if not (features and "inputs_segmentation" in features):
    # Usual case - not a packed dataset.
    ignore_padding = common_attention.attention_bias_ignore_padding(
        common_attention.embedding_to_padding(encoder_input))
    if unidirectional:
      tf.logging.info("Using unidirectional encoder")
      self_bias = common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(inputs)[1])
    else:
      self_bias = ignore_padding
    enc_dec_bias = ignore_padding
    inputs_position = None
  else:
    # Packed dataset. Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    if unidirectional:
      # NOTE(review): the causal mask alone does not block attention across
      # packed segments; confirm this is intended.
      tf.logging.info("Using unidirectional encoder")
      self_bias = common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(inputs)[1])
    else:
      self_bias = common_attention.attention_bias_same_segment(
          inputs_segmentation, inputs_segmentation)
    enc_dec_bias = common_attention.attention_bias_same_segment(
        targets_segmentation, inputs_segmentation)

  if hparams.proximity_bias:
    self_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])

  use_target_space = (target_space is not None and
                      hparams.get("use_target_space_embedding", True))
  if use_target_space:
    # Append target_space_id embedding to inputs.
    target_space_emb = common_layers.embedding(
        target_space,
        32,
        static_shape[-1],
        name="target_space_embedding",
        dtype=hparams.get("activation_dtype", "float32"),
        reuse=reuse_target_embedding)
    encoder_input += tf.reshape(target_space_emb, [1, 1, -1])

  if hparams.pos == "timing":
    if inputs_position is None:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    else:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
  elif hparams.pos == "timing_from_features":
    encoder_input = common_attention.add_timing_signals_from_features(
        encoder_input, features, hparams.position_features)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)

  # Add type embeddings
  if type_ids is not None:
    if not num_types:
      raise ValueError("Need to set num_types as well.")
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, num_types, "inputs_type_embedding", type_ids)

  self_bias = common_layers.cast_like(self_bias, encoder_input)
  enc_dec_bias = common_layers.cast_like(enc_dec_bias, encoder_input)
  return encoder_input, self_bias, enc_dec_bias
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None,
                        attn_bias_for_padding=None):
  """Runs the stack of Transformer encoder layers over `encoder_input`.

  Every layer is a self-attention sublayer followed by a feed-forward
  sublayer; each sublayer is wrapped by `common_layers.layer_preprocess` on
  its input and `common_layers.layer_postprocess` on its output. The final
  stack output goes through one more `layer_preprocess`.

  Args:
    encoder_input: a Tensor holding the prepared encoder inputs.
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for the model.
    name: variable scope name used for the whole encoder stack.
    nonpadding: optional Tensor with shape [batch_size, encoder_length] that
      marks non-padding positions. It is passed in for "packed" datasets;
      otherwise it is inferred from `attn_bias_for_padding` (when given) or
      from `encoder_self_attention_bias`. It drives the PadRemover and masks
      padding in convolutional feed-forward layers.
    save_weights_to: optional dictionary that collects attention weights for
      visualization, keyed by a string derived from the variable scope.
    make_image_summary: whether to make an attention image summary.
    losses: optional list onto which extra training losses are appended.
    attn_bias_for_padding: optional bias used only to infer padding, for the
      case of a unidirectional encoder whose self-attention bias also masks
      future positions.

  Returns:
    The encoder output Tensor.
  """
  num_layers = hparams.num_encoder_layers or hparams.num_hidden_layers
  dropout_broadcast_dims = common_layers.comma_separated_string_to_integer_list(
      getattr(hparams, "attention_dropout_broadcast_dims", ""))

  # MLPerf compliance logging of the encoder hyperparameters.
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS, value=num_layers)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
      value=hparams.attention_dropout)
  attention_dense_info = {
      "use_bias": "false",
      "num_heads": hparams.num_heads,
      "hidden_size": hparams.hidden_size
  }
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DENSE, value=attention_dense_info)

  num_area_layers = hparams.get("num_area_layers", 0)
  is_training = (hparams.get("mode", tf_estimator.ModeKeys.TRAIN) ==
                 tf_estimator.ModeKeys.TRAIN)

  hidden = encoder_input
  with tf.variable_scope(name):
    if nonpadding is None:
      # Recover padding from an attention bias; prefer the dedicated padding
      # bias because the self-attention bias may also mask future positions.
      bias_for_padding = (encoder_self_attention_bias
                          if attn_bias_for_padding is None
                          else attn_bias_for_padding)
      padding = common_attention.attention_bias_to_padding(bias_for_padding)
      nonpadding = 1.0 - padding
    else:
      padding = 1.0 - nonpadding

    use_pad_remover = (hparams.use_pad_remover and
                       not common_layers.is_xla_compiled())
    pad_remover = expert_utils.PadRemover(padding) if use_pad_remover else None

    for layer_idx in range(num_layers):
      with tf.variable_scope("layer_%d" % layer_idx):
        with tf.variable_scope("self_attention"):
          # Area attention settings only apply to the first
          # `num_area_layers` layers; the rest use plain attention.
          if layer_idx < num_area_layers:
            area_width = hparams.get("max_area_width", 1)
            area_height = hparams.get("max_area_height", 1)
            mem_height = hparams.get("memory_height", 1)
          else:
            area_width, area_height, mem_height = 1, 1, 1
          attended = common_attention.multihead_attention(
              common_layers.layer_preprocess(hidden, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"),
              hard_attention_k=hparams.get("hard_attention_k", 0),
              gumbel_noise_weight=hparams.get("gumbel_noise_weight", 0.0),
              max_area_width=area_width,
              max_area_height=area_height,
              memory_height=mem_height,
              area_key_mode=hparams.get("area_key_mode", "none"),
              area_value_mode=hparams.get("area_value_mode", "none"),
              training=is_training)
          hidden = common_layers.layer_postprocess(hidden, attended, hparams)
        with tf.variable_scope("ffn"):
          ffn_out = transformer_ffn_layer(
              common_layers.layer_preprocess(hidden, hparams),
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding,
              losses=losses)
          hidden = common_layers.layer_postprocess(hidden, ffn_out, hparams)

    # When normalization happens in layer_preprocess, the stack output must be
    # normalized too: it is the sum of every layer's unnormalized output and
    # can grow very large.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NORM,
        value={"hidden_size": hparams.hidden_size})
    return common_layers.layer_preprocess(hidden, hparams)
def transformer_ffn_layer(x,
                          hparams,
                          pad_remover=None,
                          conv_padding="LEFT",
                          nonpadding_mask=None,
                          losses=None,
                          cache=None,
                          decode_loop_step=None,
                          readout_filter_size=0,
                          layer_collection=None):
  """Feed-forward layer in the transformer.

  The layer variant is chosen by `hparams.ffn_layer`. The value "none"
  returns `x` unchanged.

  Args:
    x: a Tensor of shape [batch_size, length, hparams.hidden_size]
    hparams: hyperparameters for model
    pad_remover: an expert_utils.PadRemover object tracking the padding
      positions. If provided, when using convolutional settings, the padding
      is removed before applying the convolution, and restored afterward. This
      can give a significant speedup.
    conv_padding: a string - either "LEFT" or "SAME".
    nonpadding_mask: an optional Tensor with shape [batch_size, length].
      needed for convolutional layers with "SAME" padding.
      Contains 1.0 in positions corresponding to nonpadding.
    losses: optional list onto which to append extra training losses. Required
      for the mixture-of-experts layers ("local_moe", "local_moe_tpu").
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
      Only used for inference on TPU.
    readout_filter_size: if it's greater than 0, then it will be used instead of
      filter_size
    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
      KFAC optimizer. Default is None.

  Returns:
    a Tensor of shape [batch_size, length, hparams.hidden_size]

  Raises:
    ValueError: If losses arg is None, but layer generates extra losses.
  """
  ffn_layer = hparams.ffn_layer
  relu_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "relu_dropout_broadcast_dims", "")))
  if ffn_layer == "conv_hidden_relu":
    # Backwards compatibility
    ffn_layer = "dense_relu_dense"
  if ffn_layer == "dense_relu_dense":
    # In simple convolution mode, use `pad_remover` to speed up processing.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE,
        value={
            "filter_size": hparams.filter_size,
            "use_bias": "True",
            "activation": mlperf_log.RELU
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE,
        value={
            "hidden_size": hparams.hidden_size,
            "use_bias": "True",
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout)
    if pad_remover:
      original_shape = common_layers.shape_list(x)
      # Collapse `x` across examples, and remove padding positions.
      x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0))
      x = tf.expand_dims(pad_remover.remove(x), axis=0)
    conv_output = common_layers.dense_relu_dense(
        x,
        hparams.filter_size,
        hparams.hidden_size,
        dropout=hparams.relu_dropout,
        dropout_broadcast_dims=relu_dropout_broadcast_dims,
        layer_collection=layer_collection)
    if pad_remover:
      # Restore `conv_output` to the original shape of `x`, including padding.
      conv_output = tf.reshape(
          pad_remover.restore(tf.squeeze(conv_output, axis=0)), original_shape)
    return conv_output
  elif ffn_layer == "conv_relu_conv":
    return common_layers.conv_relu_conv(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        first_kernel_size=hparams.conv_first_kernel,
        second_kernel_size=1,
        padding=conv_padding,
        nonpadding_mask=nonpadding_mask,
        dropout=hparams.relu_dropout,
        cache=cache,
        decode_loop_step=decode_loop_step)
  elif ffn_layer == "parameter_attention":
    return common_attention.parameter_attention(
        x, hparams.parameter_attention_key_channels or hparams.hidden_size,
        hparams.parameter_attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, readout_filter_size or hparams.filter_size,
        hparams.num_heads,
        hparams.attention_dropout)
  elif ffn_layer == "conv_hidden_relu_with_sepconv":
    return common_layers.conv_hidden_relu(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        kernel_size=(3, 1),
        second_kernel_size=(31, 1),
        padding="LEFT",
        dropout=hparams.relu_dropout)
  elif ffn_layer == "sru":
    return common_layers.sru(x)
  elif ffn_layer == "local_moe_tpu":
    overhead = hparams.moe_overhead_eval
    if hparams.mode == tf_estimator.ModeKeys.TRAIN:
      overhead = hparams.moe_overhead_train
    ret, loss = expert_utils.local_moe_tpu(
        x,
        hparams.filter_size // 2,
        hparams.hidden_size,
        hparams.moe_num_experts,
        overhead=overhead,
        loss_coef=hparams.moe_loss_coef)
  elif ffn_layer == "local_moe":
    ret, loss = expert_utils.local_moe(
        x,
        True,
        expert_utils.ffn_expert_fn(hparams.hidden_size, [hparams.filter_size],
                                   hparams.hidden_size),
        hparams.moe_num_experts,
        k=hparams.moe_k,
        hparams=hparams)
  else:
    assert ffn_layer == "none"
    return x

  # Only the mixture-of-experts branches reach this point. Both produce an
  # auxiliary (load-balancing) loss that must be handed back to the caller;
  # previously the "local_moe_tpu" branch fell through and returned None.
  if losses is None:
    raise ValueError(
        "ffn_layer=%s generates an extra loss, but `losses` is None." %
        ffn_layer)
  losses.append(loss)
  return ret
================================================
FILE: tensor2tensor/layers/transformer_memory.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The memory unit for Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_layers
import tensorflow.compat.v1 as tf
class RecurrentMemory(object):
  """No-op base class that defines the recurrent memory interface.

  Subclasses hook into self-attention twice: `pre_attention` runs before the
  attention call and may rewrite its inputs, and `post_attention` runs after
  self-attention and the feed-forward layer and may update the memory. This
  base implementation passes every input through untouched.
  """

  def pre_attention(self, segment, query_antecedent, memory_antecedent, bias):
    """Hook called before self-attention to incorporate memory items.

    Args:
      segment: an integer Tensor with shape [batch]; ignored here.
      query_antecedent: a Tensor with shape [batch, length_q, channels]
      memory_antecedent: must be None. Attention normally allows this to be a
        Tensor with shape [batch, length_m, channels], but we currently only
        support memory for decoder-side self-attention.
      bias: bias Tensor (see attention_bias())

    Returns:
      (data, new_query_antecedent, new_memory_antecedent, new_bias), where
      `data` is None and the other three are the inputs unchanged.
    """
    del segment  # The no-op memory has no per-segment state.
    passthrough = (query_antecedent, memory_antecedent, bias)
    return (None,) + passthrough

  def post_attention(self, token, x):
    """Hook called after self-attention; the no-op memory returns `x` as is.

    Args:
      token: the `data` value returned by pre_attention; must be None here.
      x: a Tensor of data after self-attention and feed-forward

    Returns:
      `x`, unmodified.
    """
    assert token is None
    return x
class RecentTokensMemory(RecurrentMemory):
  """A memory module that caches features for recent tokens.

  When the number of tokens cached is equal to the chunk size, this is
  equivalent to the memory used by Transformer-XL
  (https://arxiv.org/abs/1901.02860)
  """

  def __init__(self, name, hparams):
    """Creates the non-trainable local variables that hold the memory.

    Args:
      name: variable scope for the memory variables.
      hparams: hyperparameters. Reads `hidden_size` and
        `split_targets_chunk_length`. Optionally reads `num_memory_items` and
        `recurrent_memory_batch_size`; otherwise it falls back to
        `batch_size // max_length`.
    """
    hidden_size = hparams.hidden_size
    self.chunk_length = hparams.split_targets_chunk_length
    assert self.chunk_length > 0, "Chunking is required to use recurrent memory"

    if hasattr(hparams, "num_memory_items") and hparams.num_memory_items > 0:
      self.tokens_to_cache = hparams.num_memory_items
    else:
      self.tokens_to_cache = self.chunk_length

    # TODO(kitaev): The implementation of the chunking code makes it somewhat
    # convoluted to figure out how many actual sequences we can have per batch.
    # The data pipeline should be revisited at some point.
    if (hasattr(hparams, "recurrent_memory_batch_size")
        and hparams.recurrent_memory_batch_size > 0):
      batch_size_in_sequences = hparams.recurrent_memory_batch_size
    else:
      # Floor division: this module uses `from __future__ import division`, so
      # `/` would produce a float, which is not a valid variable dimension.
      batch_size_in_sequences = int(hparams.batch_size // hparams.max_length)

    memory_shape = [batch_size_in_sequences, self.tokens_to_cache, hidden_size]
    bias_shape = [batch_size_in_sequences, 1, 1, self.tokens_to_cache]

    with tf.variable_scope(name):
      # Segment id of the chunk that produced the cached values, per sequence.
      self.previous_segment = tf.get_variable(
          "memsegment", (batch_size_in_sequences,),
          dtype=tf.int32, trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES],
          initializer=tf.constant_initializer(0))
      # Cached token features.
      self.previous_vals = tf.get_variable(
          "memvals", memory_shape,
          dtype=tf.float32, trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES],
          initializer=tf.constant_initializer(.0))
      # Attention bias for the cached tokens; starts fully masked (-1e9).
      self.previous_bias = tf.get_variable(
          "membias", bias_shape,
          dtype=tf.float32, trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES],
          initializer=tf.constant_initializer(-1e9))

  def pre_attention(self, segment, query_antecedent, memory_antecedent, bias):
    """Called prior to self-attention, to incorporate memory items.

    Args:
      segment: an integer Tensor with shape [batch]
      query_antecedent: a Tensor with shape [batch, length_q, channels]
      memory_antecedent: must be None. Attention normally allows this to be a
        Tensor with shape [batch, length_m, channels], but we currently only
        support memory for decoder-side self-attention.
      bias: bias Tensor (see attention_bias())

    Returns:
      (data, new_query_antecedent, new_memory_antecedent, new_bias)
    """
    assert memory_antecedent is None, "We only support language modeling"

    # In eval mode, batch size may be variable
    memory_batch_size = tf.shape(self.previous_vals)[0]
    current_batch_size = tf.shape(query_antecedent)[0]
    amount_to_pad = memory_batch_size - current_batch_size

    # If segment id is zero, don't attend back to the memory
    previous_bias = self.previous_bias[:current_batch_size, :, :, :] + tf.cast(
        tf.equal(segment[:, None, None, None], 0), tf.float32) * -1e9

    sliced_previous_vals = self.previous_vals[:current_batch_size, :, :]

    # Attend over [cached tokens; current chunk]. Gradients do not flow into
    # the cache.
    new_memory_antecedent = tf.concat(
        [tf.stop_gradient(sliced_previous_vals), query_antecedent], 1)
    new_bias = tf.concat([
        tf.tile(tf.stop_gradient(previous_bias), [1, 1, self.chunk_length, 1]),
        tf.tile(bias, [current_batch_size, 1, 1, 1]),
    ], -1)

    # Pad the current batch up to the memory batch size for storage.
    remember_segment = tf.pad(segment, [[0, amount_to_pad]])
    # TODO(kitaev): The code assumes that we always either increment the chunk
    # number or reset it to zero. This assumption will not hold if we re-run the
    # model for each token, e.g. for autoregressive greedy/beam/sampling decode.
    remember_vals = tf.pad(query_antecedent,
                           [[0, amount_to_pad], [0, 0], [0, 0]])
    # Query position is on axis -2 for bias: as long as a token can be attended
    # to from at least one query position (i.e. it's not padding), memorize it.
    remember_bias = tf.tile(
        tf.reduce_max(bias, -2, keepdims=True), [memory_batch_size, 1, 1, 1])
    # Assume that query_antecedent is always a full chunk (i.e. not truncated)
    if self.chunk_length < self.tokens_to_cache:
      # Keep older cached tokens too, masking them for sequences that restart
      # (segment id zero).
      remember_vals = tf.concat([self.previous_vals, remember_vals], 1)
      remember_bias = tf.concat([
          self.previous_bias - 1e9 * tf.cast(
              tf.equal(remember_segment[:, None, None, None], 0), tf.float32),
          remember_bias
      ], -1)

    if self.chunk_length != self.tokens_to_cache:
      # Keep only the most recent `tokens_to_cache` positions.
      remember_vals = remember_vals[:, -self.tokens_to_cache:, :]
      remember_bias = remember_bias[:, :, :, -self.tokens_to_cache:]
    token = (remember_segment, remember_vals, remember_bias)

    return token, query_antecedent, new_memory_antecedent, new_bias

  def post_attention(self, token, x):
    """Called after self-attention. The memory can be updated here.

    Args:
      token: Data returned by pre_attention, which can be used to carry over
        state related to the current memory operation.
      x: a Tensor of data after self-attention and feed-forward

    Returns:
      a (possibly modified) version of the input x
    """
    # Store the new memory state before `x` is consumed downstream.
    with tf.control_dependencies([
        self.previous_segment.assign(token[0]),
        self.previous_vals.assign(token[1]),
        self.previous_bias.assign(token[2]),
        ]):
      return tf.identity(x)
class TransformerMemory(object):
  """Implements the Memory module.

  Based on Neural Turing Machines: arXiv:1410.5401 [cs.NE]

  State lives in three non-trainable variables created by the constructor:
  `mem_vals` (slot contents per batch element), `mean_logits` (a running
  average of write logits per slot) and `segment_number` (the last segment
  number seen per batch element).
  """

  def __init__(self, batch_size, key_depth, val_depth, memory_size,
               sharpen_factor=1., name="neural_memory"):
    """Initialize the memory object.

    Args:
      batch_size: the batch size. The memory variables are sized for exactly
        this many batch elements; pre_attention pads smaller batches up to it.
      key_depth: the depth of the memory keys.
      val_depth: the depth of the memory values.
      memory_size: the number of items in the memory.
      sharpen_factor: the sharpen_factor for addressing the memory; it
        multiplies the cosine similarities that become access logits.
      name: the optional variable scope.
    """
    self.name = name
    self.batch_size = batch_size
    self.key_depth = key_depth
    self.val_depth = val_depth
    self.memory_size = memory_size
    self.sharpen_factor = sharpen_factor
    with tf.variable_scope(name):
      # Starts at a large value; pre_attention resets every slot whose
      # incoming segment number is smaller than the stored one, so presumably
      # the first call resets all entries.
      self.segment_number = tf.get_variable(
          "segment_number", [self.batch_size],
          dtype=tf.int32, trainable=False,
          initializer=tf.constant_initializer(100000))
      # Memory contents: [batch_size, memory_size, val_depth].
      self.mem_vals = tf.get_variable(
          "memvals", [self.batch_size, self.memory_size, self.val_depth],
          dtype=tf.float32, trainable=False,
          initializer=tf.constant_initializer(.0))
      # Moving average of write logits per slot: [batch_size, memory_size].
      self.mean_logits = tf.get_variable(
          "meanlogits", [self.batch_size, self.memory_size],
          dtype=tf.float32, trainable=False,
          initializer=tf.constant_initializer(.0))

  def _norm(self, x):
    """Compute the safe norm.

    L2 norm over the last axis (kept as a size-1 axis). The 1e-7 inside the
    sqrt presumably keeps the gradient finite for all-zero vectors.
    """
    return tf.sqrt(tf.reduce_sum(tf.square(x), keepdims=True, axis=-1) + 1e-7)

  def _address_content(self, x):
    """Address the memory based on content similarity.

    Both the memory values and `x` are projected to `key_depth` by dense
    layers ("mem_key" and "mem_query"). The logits are their cosine
    similarity scaled by `sharpen_factor`.

    Args:
      x: a tensor in the shape of [batch_size, length, depth].
    Returns:
      the logits for each memory entry [batch_size, length, memory_size].
    """
    mem_keys = tf.layers.dense(self.mem_vals, self.key_depth,
                               bias_initializer=tf.constant_initializer(1.0),
                               name="mem_key")
    mem_query = tf.layers.dense(x, self.key_depth,
                                bias_initializer=tf.constant_initializer(1.0),
                                name="mem_query")
    # Outer product of the norms gives the cosine denominator per (query, key).
    norm = tf.matmul(self._norm(mem_query), self._norm(mem_keys),
                     transpose_b=True)
    dot_product = tf.matmul(mem_query, mem_keys, transpose_b=True)
    cos_dist = tf.div(dot_product, norm + 1e-7, name="cos_dist")
    access_logits = self.sharpen_factor * cos_dist
    return access_logits

  def read(self, x):
    """Read from the memory.

    An external component can use the results via a simple MLP,
    e.g., fn(x W_x + retrieved_mem W_m).

    Args:
      x: a tensor in the shape of [batch_size, length, depth].
    Returns:
      access_logits: the logits for accessing the memory in shape of
          [batch_size, length, memory_size].
      retrieved_mem: the retrieved results in the shape of
          [batch_size, length, val_depth].
    """
    access_logits = self._address_content(x)
    # Attention weights over memory slots (softmax on the last axis).
    weights = tf.nn.softmax(access_logits)
    # Weighted sum of slot values:
    # [b, len, mem, 1] * [b, 1, mem, val] summed over the memory axis.
    retrieved_mem = tf.reduce_sum(
        tf.multiply(tf.expand_dims(weights, 3),
                    tf.expand_dims(self.mem_vals, axis=1)), axis=2)
    return access_logits, retrieved_mem

  def write(self, x, access_logits):
    """Write to the memory based on a combination of similarity and least used.

    Based on arXiv:1607.00036v2 [cs.LG].

    Args:
      x: a tensor in the shape of [batch_size, length, depth].
      access_logits: the logits for accessing the memory.
    Returns:
      the update op.
    """
    # Learned gate in (0, 1) that subtracts the running slot usage from the
    # access logits, which biases writes away from frequently used slots.
    gamma = tf.layers.dense(x, 1, activation=tf.sigmoid, name="gamma")
    write_logits = access_logits - gamma * tf.expand_dims(self.mean_logits, 1)
    candidate_value = tf.layers.dense(x, self.val_depth,
                                      activation=tf.nn.relu,
                                      name="candidate_value")
    erase_gates = tf.layers.dense(x, self.memory_size,
                                  activation=tf.nn.sigmoid,
                                  name="erase")
    write_weights = tf.nn.softmax(write_logits)
    # Per position: keep (1 - erase_gate * write_weight) of each old slot value
    # and add write_weight * candidate_value.
    erase_weights = tf.expand_dims(1 - erase_gates * write_weights, 3)
    erase = tf.multiply(erase_weights,
                        tf.expand_dims(self.mem_vals, 1))
    addition = tf.multiply(
        tf.expand_dims(write_weights, 3),
        tf.expand_dims(candidate_value, 2))
    # The per-position updates are averaged over the length axis.
    update_value_op = self.mem_vals.assign(
        tf.reduce_mean(erase + addition, axis=1))
    with tf.control_dependencies([update_value_op]):
      # Moving average: 0.1 * old + 0.9 * mean of the new write logits.
      write_op = self.mean_logits.assign(
          self.mean_logits * 0.1 + tf.reduce_mean(write_logits * 0.9, axis=1))
      return write_op

  def set(self, mem_vals, mean_logits):
    """Returns an op that overwrites `mem_vals` and `mean_logits`."""
    set_op = tf.group([
        self.mem_vals.assign(mem_vals),
        self.mean_logits.assign(mean_logits)])
    return set_op

  def get(self):
    """Returns the (mem_vals, mean_logits) variables."""
    return self.mem_vals, self.mean_logits

  def update_segment_number(self, segment_number):
    """Returns an op that stores `segment_number` into the state variable."""
    return self.segment_number.assign(segment_number)

  def reset(self, entries_to_reset):
    """Reset the entries in the memory.

    Zeroes `mem_vals` and `mean_logits` for the given batch rows. Does not
    touch `segment_number`.

    Args:
      entries_to_reset: a 1D tensor.
    Returns:
      the reset op.
    """
    num_updates = tf.size(entries_to_reset)
    update_vals = tf.scatter_update(
        self.mem_vals, entries_to_reset,
        tf.tile(tf.expand_dims(
            tf.fill([self.memory_size, self.val_depth], .0), 0),
                [num_updates, 1, 1]))
    update_logits = tf.scatter_update(
        self.mean_logits, entries_to_reset,
        tf.tile(tf.expand_dims(
            tf.fill([self.memory_size], .0), 0),
                [num_updates, 1]))
    reset_op = tf.group([update_vals, update_logits])
    return reset_op

  def pre_attention(self, segment_number, query_antecedent,
                    memory_antecedent, bias):
    """Called prior to self-attention, to incorporate memory items.

    The sequence of ops is: reset the rows whose segment number went down,
    then store the new segment numbers, then read the memory with the
    (batch-padded) query. The attention inputs are returned unchanged. The
    read results are carried to post_attention in a dict with keys "x",
    "access_logits" and "retrieved_mem".

    Args:
      segment_number: an integer Tensor with shape [batch]
      query_antecedent: a Tensor with shape [batch, length_q, channels]
      memory_antecedent: must be None. Attention normally allows this to be a
        Tensor with shape [batch, length_m, channels], but we currently only
        support memory for decoder-side self-attention.
      bias: bias Tensor (see attention_bias())
    Returns:
      (data, new_query_antecedent, new_memory_antecedent, new_bias)
    """
    with tf.variable_scope(self.name + "/pre_attention", reuse=tf.AUTO_REUSE):
      assert memory_antecedent is None, "We only support language modeling"
      # The incoming batch may be smaller than the memory's batch (e.g. in
      # eval); pad it up to `self.batch_size`.
      with tf.control_dependencies([
          tf.assert_greater_equal(self.batch_size, tf.size(segment_number))]):
        difference = self.batch_size - tf.size(segment_number)
        segment_number = tf.pad(segment_number, [[0, difference]])
        # Rows whose segment number decreased, flattened to 1-D indices.
        reset_op = self.reset(tf.reshape(tf.where(
            tf.less(segment_number, self.segment_number)), [-1]))
      memory_results = {}
      with tf.control_dependencies([reset_op]):
        with tf.control_dependencies([
            self.update_segment_number(segment_number)]):
          x = tf.pad(query_antecedent, [
              [0, difference], [0, 0], [0, 0]])
          access_logits, retrieved_mem = self.read(x)
      memory_results["x"] = x
      memory_results["access_logits"] = access_logits
      memory_results["retrieved_mem"] = retrieved_mem
      return memory_results, query_antecedent, memory_antecedent, bias

  def post_attention(self, token, x):
    """Called after self-attention. The memory can be updated here.

    The output is dense(x) + dense(retrieved memory), with the retrieved
    memory trimmed back to the actual batch size. The memory write is
    sequenced after the output is computed.

    Args:
      token: Data returned by pre_attention, which can be used to carry over
          state related to the current memory operation.
      x: a Tensor of data after self-attention and feed-forward
    Returns:
      a (possibly modified) version of the input x
    """
    with tf.variable_scope(self.name + "/post_attention", reuse=tf.AUTO_REUSE):
      depth = common_layers.shape_list(x)[-1]
      actual_batch_size = common_layers.shape_list(x)[0]
      # Drop the rows that pre_attention added as batch padding.
      memory_output = tf.gather(token["retrieved_mem"],
                                tf.range(actual_batch_size))
      output = tf.add(tf.layers.dense(x, depth, use_bias=False),
                      tf.layers.dense(memory_output, depth))
      with tf.control_dependencies([output]):
        with tf.control_dependencies([
            self.write(token["x"], token["access_logits"])]):
          return tf.identity(output)
================================================
FILE: tensor2tensor/layers/transformer_memory_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor2tensor.layers.transformer_memory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensor2tensor.layers import transformer_memory
import tensorflow.compat.v1 as tf
class TransformerMemoryTest(parameterized.TestCase, tf.test.TestCase):
  """Graph-mode tests for transformer_memory.TransformerMemory."""

  def testRead(self):
    batch_size = 2
    key_depth = 3
    val_depth = 5
    memory_size = 4
    window_size = 6
    x_depth = 10
    memory = transformer_memory.TransformerMemory(
        batch_size, key_depth, val_depth, memory_size)
    x = tf.random_uniform([batch_size, window_size, x_depth], minval=1.0)
    vals = tf.random_uniform([batch_size, memory_size, val_depth], minval=1.0)
    logits = tf.random_uniform([batch_size, memory_size], minval=1.0)
    update_op = memory.set(vals, logits)
    with tf.control_dependencies([update_op]):
      logits, retrieved_values = memory.read(x)
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      logits_values, values = session.run([logits, retrieved_values])
    self.assertAllEqual([batch_size, window_size, memory_size],
                        logits_values.shape)
    self.assertAllEqual([batch_size, window_size, val_depth], values.shape)

  def testWrite(self):
    batch_size = 2
    key_depth = 3
    val_depth = 5
    memory_size = 4
    window_size = 6
    x_depth = 10
    memory = transformer_memory.TransformerMemory(
        batch_size, key_depth, val_depth, memory_size)
    x = tf.random_uniform([batch_size, window_size, x_depth], minval=1.0)
    vals = tf.random_uniform([batch_size, memory_size, val_depth], minval=1.0)
    logits = tf.random_uniform([batch_size, memory_size], minval=1.0)
    update_op = memory.set(vals, logits)
    with tf.control_dependencies([update_op]):
      logits, _ = memory.read(x)
      write_op = memory.write(x, logits)
    mem_vals, mem_logits = memory.get()
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(write_op)
      updated_vals, updated_logits = session.run([mem_vals, mem_logits])
    self.assertAllEqual([batch_size, memory_size, val_depth],
                        updated_vals.shape)
    self.assertAllEqual([batch_size, memory_size], updated_logits.shape)

  def testReset(self):
    batch_size = 2
    key_depth = 3
    val_depth = 5
    memory_size = 4
    memory = transformer_memory.TransformerMemory(
        batch_size, key_depth, val_depth, memory_size)
    # Deterministic, non-zero contents. Random ops are re-sampled on every
    # session.run, so a random `vals` written in one run would not equal the
    # `vals` re-evaluated by the assertions in a later run.
    vals = tf.fill([batch_size, memory_size, val_depth], 2.0)
    logits = tf.fill([batch_size, memory_size], 3.0)
    update_op = memory.set(vals, logits)
    reset_op = memory.reset([1])
    mem_vals, mem_logits = memory.get()
    # Row 0 must survive the reset of row 1 untouched.
    assert_op1 = tf.assert_equal(mem_vals[0], vals[0])
    assert_op2 = tf.assert_equal(mem_logits[0], logits[0])
    with tf.control_dependencies([assert_op1, assert_op2]):
      all_zero1 = tf.reduce_sum(tf.abs(mem_vals[1]))
      all_zero2 = tf.reduce_sum(tf.abs(mem_logits[1]))
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(update_op)
      session.run(reset_op)
      zero1, zero2 = session.run([all_zero1, all_zero2])
    self.assertAllEqual(0, zero1)
    self.assertAllEqual(0, zero2)

  def testLoss(self):
    batch_size = 2
    key_depth = 5
    val_depth = 5
    memory_size = 4
    window_size = 3
    x_depth = 5
    memory = transformer_memory.TransformerMemory(
        batch_size, key_depth, val_depth, memory_size)
    x = tf.random_uniform([batch_size, window_size, x_depth], minval=.0)
    memory_results, _, _, _ = (
        memory.pre_attention(
            tf.random_uniform([batch_size], minval=0, maxval=1, dtype=tf.int32),
            x, None, None))
    x = memory.post_attention(memory_results, x)
    is_nan = tf.reduce_any(tf.math.is_nan(x))
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      for _ in range(100):
        is_nan_value, _ = session.run([is_nan, x])
    self.assertEqual(is_nan_value, False)


if __name__ == "__main__":
  tf.test.main()
================================================
FILE: tensor2tensor/layers/vq_discrete.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Clean discrete bottleneck as in https://arxiv.org/abs/1805.11063."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
from tensor2tensor.layers import common_layers
import tensorflow.compat.v1 as tf
from tensorflow.python.training import moving_averages
class DiscreteBottleneck(object):
"""Discrete bottleneck class."""
def __init__(self, hparams):
  """Derives bottleneck sizes from `hparams` and creates the codebook.

  Writes `z_size_per_residual`, `block_dim` and `block_v_size` back onto
  `hparams` (mutating it in place). Then creates the `means` codebook of shape
  [num_blocks, block_v_size, block_dim]. When `hparams.ema` is set, it also
  creates the non-trainable `ema_count` and `ema_means` shadow variables.

  Args:
    hparams: hyperparameters; mutated in place.
  """
  self.hparams = hparams
  tf.logging.info("self.hparams.z_size %s", self.hparams.z_size)
  # Set the discretization bottleneck specific things here
  self.hparams.z_size_per_residual = (
      self.hparams.z_size // self.hparams.num_residuals)
  tf.logging.info("self.hparams.num_residuals %s", self.hparams.num_residuals)
  self.hparams.block_dim = int(
      self.hparams.hidden_size // self.hparams.num_blocks)
  # Codes per block: 2 ** (bits per block). True division is kept so existing
  # configurations get identical sizes; int() truncates the result.
  self.hparams.block_v_size = 2**(
      self.hparams.z_size_per_residual / self.hparams.num_blocks)
  self.hparams.block_v_size = int(self.hparams.block_v_size)
  self.means = tf.get_variable(
      name="means",
      shape=[
          self.hparams.num_blocks, self.hparams.block_v_size,
          self.hparams.block_dim
      ],
      initializer=tf.initializers.variance_scaling(distribution="uniform"))

  # Create the shadow variables if we are using EMA
  if self.hparams.ema:
    self.ema_count = tf.get_variable(
        "ema_count", [self.hparams.num_blocks, self.hparams.block_v_size],
        initializer=tf.constant_initializer(0),
        trainable=False)
    with tf.colocate_with(self.means):
      self.ema_means = tf.get_variable(
          "ema_means",
          initializer=self.means.initialized_value(),
          trainable=False)
def slice_hidden(self, x):
  """Splits each flat hidden vector into `num_blocks` blocks of `block_dim`.

  Args:
    x: encoder hidden state of shape [-1, hidden_size].

  Returns:
    The same values reshaped to [-1, num_blocks, block_dim].
  """
  target_shape = [-1, self.hparams.num_blocks, self.hparams.block_dim]
  return tf.reshape(x, shape=target_shape)
def nearest_neighbor(self, x, means):
  """Find the nearest element in means to elements in x.

  Distances are squared Euclidean, computed per block as
  |x|^2 + |m|^2 - 2 x.m. The code is then chosen by one of three rules:
  sampled from softmax(-dist) when `soft_em` is set (averaged over
  `num_samples` draws), picked uniformly among the `random_top_k` nearest
  when `random_top_k > 1`, or otherwise the argmin.

  Args:
    x: Batch of encoder continuous latent states sliced/projected into
      shape [-1, num_blocks, block_dim].
    means: Embedding means of shape [num_blocks, block_v_size, block_dim].

  Returns:
    Tensor with nearest element in mean encoded in one-hot notation.
  """
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  # Per-block dot products: [num_blocks, batch, dim] x [num_blocks, dim, v].
  scalar_prod = tf.matmul(
      tf.transpose(x, perm=[1, 0, 2]), tf.transpose(means, perm=[0, 2, 1]))
  scalar_prod = tf.transpose(scalar_prod, perm=[1, 0, 2])
  # dist: [batch, num_blocks, block_v_size].
  dist = x_norm_sq + tf.transpose(
      means_norm_sq, perm=[2, 0, 1]) - 2 * scalar_prod

  if self.hparams.soft_em:
    nearest_idx = tf.stack(
        [
            tf.multinomial(
                -dist[:, i, :], num_samples=self.hparams.num_samples)
            for i in range(self.hparams.num_blocks)
        ],
        axis=1)
    nearest_hot = tf.one_hot(nearest_idx, depth=self.hparams.block_v_size)
    nearest_hot = tf.reduce_mean(nearest_hot, axis=-2)
  else:
    if self.hparams.random_top_k > 1:
      _, top_k_idx = tf.nn.top_k(-dist, k=self.hparams.random_top_k)
      # Integer tf.random_uniform excludes `maxval`, so `maxval=random_top_k`
      # samples a position uniformly from [0, random_top_k). The previous
      # `random_top_k - 1` never selected the k-th candidate.
      # NOTE(review): the gathered index keeps a trailing size-1 axis, so the
      # one-hot below has shape [batch, num_blocks, 1, block_v_size];
      # embedding_lookup reshapes it away.
      nearest_idx = tf.gather(
          top_k_idx,
          tf.random_uniform(
              [1],
              minval=0,
              maxval=self.hparams.random_top_k,
              dtype=tf.int32),
          axis=-1)
    else:
      if self.hparams.use_scales:
        # NOTE(review): dist's last axis is block_v_size, yet scales are
        # reshaped with moe_num_experts; confirm these are meant to match.
        dist /= tf.reshape(self.hparams.scales,
                           [1, 1, self.hparams.moe_num_experts])
      nearest_idx = tf.argmax(-dist, axis=-1)
    nearest_hot = tf.one_hot(nearest_idx, self.hparams.block_v_size)
  return nearest_hot
def embedding_lookup(self, x, means):
  """Quantize x to its nearest means and compute the two VQ losses.

  Args:
    x: Batch of encoder continuous latent states sliced/projected into
      shape [-1, num_blocks, block_dim].
    means: Embedding means, multiplied per block with the one-hot codes
      laid out as [num_blocks, -1, block_v_size].

  Returns:
    A tuple (x_means_hot, x_means, q_loss, e_loss):
      x_means_hot: the code assignment returned by nearest_neighbor.
      x_means: the selected means, shape [-1, num_blocks, block_dim].
      q_loss: mean squared error with x held fixed (gradients reach the
        means; the embedding training loss).
      e_loss: mean squared error with the selected means held fixed
        (gradients reach x; the commitment loss).
  """
  hp = self.hparams
  codes_hot = self.nearest_neighbor(x, means)
  codes_by_block = tf.transpose(
      tf.reshape(codes_hot, [-1, hp.num_blocks, hp.block_v_size]),
      perm=[1, 0, 2])
  # [num_blocks, -1, v] @ means -> [num_blocks, -1, dim], then batch-major.
  selected = tf.transpose(tf.matmul(codes_by_block, means), [1, 0, 2])
  q_loss = tf.reduce_mean(
      tf.squared_difference(tf.stop_gradient(x), selected))
  e_loss = tf.reduce_mean(
      tf.squared_difference(x, tf.stop_gradient(selected)))
  return codes_hot, selected, q_loss, e_loss
def bit_to_int(self, x_bit, num_bits, base=2):
  """Collapse lower-endian digits along the last axis into integers.

  Args:
    x_bit: Tensor whose trailing num_bits entries are the digits (in
      `base`) of each number, least significant first.
    num_bits: Number of digits per number.
    base: Base of the representation.

  Returns:
    int32 tensor of shape x_bit.shape[:-1]. No gradient flows back into
    x_bit.
  """
  digits = tf.stop_gradient(tf.to_int32(tf.reshape(x_bit, [-1, num_bits])))
  radix = tf.to_int32(base)
  total = 0
  for position in range(num_bits):
    total += digits[:, position] * radix**tf.to_int32(position)
  return tf.to_int32(tf.reshape(total, common_layers.shape_list(x_bit)[:-1]))
def int_to_bit(self, x_int, num_bits, base=2):
  """Expand integers into lower-endian digits along a new last axis.

  Args:
    x_int: Tensor containing integers to be converted into base notation.
    num_bits: Number of digits to produce per integer.
    base: Base of the representation.

  Returns:
    float tensor of shape x_int.shape + [num_bits]; entry i holds
    floor(x / base**i) mod base.
  """
  expanded = tf.to_int32(tf.expand_dims(x_int, axis=-1))
  radix = tf.to_int32(base)
  digits = []
  for position in range(num_bits):
    shifted = tf.floordiv(expanded, radix**position)
    digits.append(tf.floormod(shifted, radix))
  return tf.to_float(tf.concat(digits, axis=-1))
def embed(self, x):
  """Embedding function that takes discrete latent and returns embedding.

  Args:
    x: Integer discrete latent codes of z_size bits each, e.g. the
      "discrete" output of discrete_bottleneck.

  Returns:
    Continuous embedding of shape x.shape + [hidden_size] to be passed on
    to the decoder.
  """
  num_blocks = self.hparams.num_blocks
  bits_per_block = int(self.hparams.z_size / self.hparams.num_blocks)
  shape_x = common_layers.shape_list(x)
  x_flat = tf.reshape(x, [-1, 1])
  # Lower-endian bits of every code: [N, 1, z_size].
  c = self.int_to_bit(x_flat, num_bits=self.hparams.z_size, base=2)
  # Split the z_size bits into per-block groups: [N, 1, num_blocks,
  # bits_per_block]. The last axis must be *replaced* (not extended);
  # appending kept z_size in the shape and made the reshape impossible.
  new_shape = common_layers.shape_list(c)
  new_shape[-1] = num_blocks
  new_shape.append(bits_per_block)
  c = tf.to_int32(tf.reshape(c, shape=new_shape))
  # Per-block code index, then one-hot lookup into the per-block means.
  c_int = self.bit_to_int(c, num_bits=bits_per_block, base=2)
  c_hot = tf.one_hot(c_int, depth=self.hparams.block_v_size, axis=-1)
  c_hot_flat = tf.reshape(
      c_hot, shape=[-1, num_blocks, self.hparams.block_v_size])
  h1 = tf.matmul(tf.transpose(c_hot_flat, perm=[1, 0, 2]), self.means)
  h1 = tf.transpose(h1, perm=[1, 0, 2])
  h1 = tf.reshape(h1, shape=shape_x + [self.hparams.hidden_size])
  # Same dense stack (and layer names) as discrete_bottleneck.
  h2 = tf.layers.dense(tf.nn.relu(h1), self.hparams.filter_size, name="vch2")
  res = tf.layers.dense(
      tf.nn.relu(h2), self.hparams.hidden_size, name="vcfin")
  return res
def discrete_bottleneck(self, x):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck; its last dimension is split
      into [num_blocks, block_dim] by slice_hidden.

  Returns:
    A dict with:
      "dense": embedding to pass to the decoder (straight-through quantized
        x fed through two dense layers),
      "discrete": int32 latent codes of shape x.shape[:-1],
      "loss": the bottleneck loss,
      "embed": function mapping discrete codes back to embeddings.
  """
  x_reshaped = self.slice_hidden(x)
  loss = 0
  x_means_hot, x_means, q_loss, e_loss = self.embedding_lookup(
      x_reshaped, self.means)

  if self.hparams.ema:
    tf.logging.info("Using EMA with beta = {}".format(self.hparams.beta))
    # EMA of how often each code is selected, per block.
    updated_ema_count = moving_averages.assign_moving_average(
        self.ema_count,
        tf.reduce_sum(
            tf.reshape(
                x_means_hot,
                shape=[-1, self.hparams.num_blocks,
                       self.hparams.block_v_size]),
            axis=0),
        self.hparams.decay,
        zero_debias=False)

    # Per-block sum of the inputs assigned to each code.
    dw = tf.matmul(
        tf.transpose(x_means_hot, perm=[1, 2, 0]),
        tf.transpose(x_reshaped, perm=[1, 0, 2]))

    updated_ema_means = moving_averages.assign_moving_average(
        self.ema_means, dw, self.hparams.decay, zero_debias=False)
    # Additive (Laplace) smoothing of the counts before normalizing.
    n = tf.reduce_sum(updated_ema_count, axis=-1, keep_dims=True)
    updated_ema_count = ((updated_ema_count + self.hparams.epsilon) / (
        n + 2**self.hparams.z_size * self.hparams.epsilon) * n)
    updated_ema_means = updated_ema_means / tf.expand_dims(
        updated_ema_count, axis=-1)

    # Compute e_loss before the means are overwritten, and make the loss
    # depend on the update so the EMA step runs whenever the loss does.
    with tf.control_dependencies([e_loss]):
      update_means = tf.assign(self.means, updated_ema_means)
      with tf.control_dependencies([update_means]):
        loss += self.hparams.beta * e_loss
  else:
    # Use a gradient based loss for learning the cluster centers
    loss += q_loss + self.hparams.beta * e_loss

  # Get the discrete latent representation: one code index per block.
  x_means_idx = tf.argmax(x_means_hot, axis=-1)

  # Get the binary representation, then concatenate the per-block bits
  # (block 0 in the lowest bits) into a single z_size-bit code per position.
  # Flattening first keeps bit_to_int's output reshape valid for any
  # num_blocks; for num_blocks == 1 the resulting values are unchanged.
  num_bits = int(self.hparams.z_size // self.hparams.num_blocks)
  x_means_bits = self.int_to_bit(x_means_idx, num_bits=num_bits, base=2)
  x_means_bits = tf.reshape(x_means_bits, [-1, self.hparams.z_size])
  x_discrete = self.bit_to_int(
      tf.to_int32(x_means_bits), num_bits=self.hparams.z_size, base=2)

  # Reshape x_discrete
  shape_x = common_layers.shape_list(x)
  shape_discrete = shape_x[:-1]
  x_discrete = tf.reshape(x_discrete, shape_discrete)
  x_means = tf.reshape(x_means, shape=shape_x)
  # Straight-through estimator: forward pass uses x_means, gradient goes to x.
  h1 = x + tf.stop_gradient(x_means - x)

  h2 = tf.layers.dense(tf.nn.relu(h1), self.hparams.filter_size, name="vch2")
  res = tf.layers.dense(
      tf.nn.relu(h2), self.hparams.hidden_size, name="vcfin")
  embed_fn = partial(self.embed)
  return {
      "dense": res,
      "discrete": x_discrete,
      "loss": loss,
      "embed": embed_fn
  }
================================================
FILE: tensor2tensor/layers/vqa_layers.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some customization of common_attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import contrib
import tensorflow.compat.v1 as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim.python.slim.nets.resnet_v1 import resnet_v1_152
from tensorflow.contrib.slim.python.slim.nets.resnet_v2 import resnet_v2_152 # pylint: disable=unused-import
from tensorflow.python.ops import inplace_ops
def summarize_tensors(tensor_dict, tag=None):
  """Emit a histogram summary for every tensor in a dictionary.

  Args:
    tensor_dict: mapping from name to tensor; each entry is recorded under
      the summary name `tag + name`.
    tag: prefix (name scope) for the summaries; "tensors/" when None.
  """
  prefix = "tensors/" if tag is None else tag
  for t_name, t in list(tensor_dict.items()):
    tf.summary.histogram(prefix + t_name, t)
def image_embedding(images,
                    model_fn=resnet_v1_152,
                    trainable=True,
                    is_training=True,
                    weight_decay=0.0001,
                    batch_norm_decay=0.997,
                    batch_norm_epsilon=1e-5,
                    batch_norm_scale=True,
                    add_summaries=False,
                    reuse=False):
  """Extract image features from pretrained resnet model.

  Args:
    images: input images passed to model_fn.
    model_fn: slim resnet constructor; its __name__ is used as the variable
      scope name.
    trainable: whether the conv / batch-norm variables are trainable; also
      enables L2 weight regularization.
    is_training: training mode; batch norm runs in training mode only when
      both trainable and is_training are true.
    weight_decay: L2 regularization scale for conv weights.
    batch_norm_decay: batch-norm moving average decay.
    batch_norm_epsilon: batch-norm epsilon.
    batch_norm_scale: whether batch norm learns a scale.
    add_summaries: whether to add activation summaries for all end points.
    reuse: whether to reuse the variables of the scope.

  Returns:
    The feature map returned by model_fn (called with num_classes=None and
    global_pool=False).
  """
  is_resnet_training = trainable and is_training

  batch_norm_params = {
      "is_training": is_resnet_training,
      "trainable": trainable,
      "decay": batch_norm_decay,
      "epsilon": batch_norm_epsilon,
      "scale": batch_norm_scale,
  }

  if trainable:
    weights_regularizer = contrib.layers().l2_regularizer(weight_decay)
  else:
    weights_regularizer = None

  # `values` must be passed by keyword: the second positional parameter of
  # tf.variable_scope is `default_name`, so passing [images] positionally
  # silently set the default name instead of the scope's input values.
  with tf.variable_scope(
      model_fn.__name__, values=[images], reuse=reuse) as scope:
    with slim.arg_scope(
        [slim.conv2d],
        weights_regularizer=weights_regularizer,
        trainable=trainable):
      with slim.arg_scope(
          [slim.conv2d],
          weights_initializer=slim.variance_scaling_initializer(),
          activation_fn=tf.nn.relu,
          normalizer_fn=slim.batch_norm,
          normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm],
                            is_training=is_resnet_training,
                            trainable=trainable):
          with slim.arg_scope([slim.max_pool2d], padding="SAME"):
            net, end_points = model_fn(
                images, num_classes=None, global_pool=False,
                is_training=is_resnet_training,
                reuse=reuse, scope=scope)

  if add_summaries:
    for v in end_points.values():
      contrib.layers().summaries.summarize_activation(v)

  return net
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
                no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. For self-attention it is
           expected to contain two keys ('k' and 'v'); for the initial call
           the values for these keys should be empty Tensors of the
           appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
           For encoder-decoder attention (memory_antecedent is not None) the
           precomputed keys/values are read from 'k_encdec' and 'v_encdec'.
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function. The key
      "decode_loop_step" is also read here to select the in-place cache
      update used for TPU decoding.

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns an additional loss parameters (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads, or if cache is given without a bias.
    NotImplementedError: if cache is given with an attention_type other than
      "dot_product".
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  # 0 tells the compute_* helpers to use the regular (non-3d) projections.
  vars_3d_num_heads = num_heads if vars_3d else 0
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):

    # Project q/k/v unless k/v come entirely from the encoder-decoder cache.
    if cache is None or memory_antecedent is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, memory_antecedent,
          total_key_depth, total_value_depth, q_filter_width,
          kv_filter_width, q_padding, kv_padding,
          vars_3d_num_heads=vars_3d_num_heads)
    if cache is not None:
      if attention_type != "dot_product":
        # TODO(petershaw): Support caching when using relative position
        # representations, i.e. "dot_product_relative" attention.
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other than"
            " dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        # Only the query is projected; keys/values are reused from the cache.
        q = common_attention.compute_attention_component(
            query_antecedent, total_key_depth,
            q_filter_width, q_padding, "q",
            vars_3d_num_heads=vars_3d_num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        # Self-attention decoding: split the new step's k/v into heads and
        # append them to the cached history along the length axis (axis 2).
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of current implementation is better than updating
          # the tensor by adding the result of matmul(one_hot,
          # update_in_current_step)
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    # With a cache, k/v were already split (or loaded pre-split) above.
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    # Scale queries by 1/sqrt(depth per head). NOTE(review): skipped when
    # vars_3d, presumably because the 3d projection handles it -- confirm.
    if not vars_3d:
      if scale_dotproduct:
        q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q,
          k,
          v,
          bias,
          max_relative_position,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary)
    elif attention_type == "dot_product_relative_v2":
      x = common_attention.dot_product_self_attention_relative_v2(
          q,
          k,
          v,
          bias,
          max_length,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "local_within_block_mask_right":
      x = common_attention.masked_within_block_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "rel_local_mask_right":
      x = common_attention.masked_rel_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary,
          dropout_rate=dropout_rate,
          share_rel_embed=shared_rel)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q,
          k,
          v,
          block_length=block_length,
          make_image_summary=make_image_summary)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    # Output projection from total_value_depth to output_depth.
    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.cast(o_var, x.dtype)
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
================================================
FILE: tensor2tensor/metrics/__init__.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: tensor2tensor/metrics/video_conditional_fvd.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conditional FVD metric on video.
FVD - Frechet Video Distance
This is the metric that is inspired by FID, but applied to
video rather than to images.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
class VideoEvaluationDataset(
    collections.namedtuple(
        'VideoEvaluationDataset',
        ['n_input_frames', 'n_output_frames', 'get_video_batch_fn'])):
  """Immutable description of a video prediction problem for evaluation.

  Fields:
    n_input_frames: how many leading frames the model is conditioned on.
    n_output_frames: how many frames the model is expected to produce.
    get_video_batch_fn: callable taking a batch size and returning a tensor
      of real video shaped [batch_size, N, height, width, depth], where
      N = n_input_frames + n_output_frames.
  """
class Model(collections.namedtuple('Model', ['apply_fn', 'load_fn'])):
  """Immutable bundle describing the model under evaluation.

  Fields:
    apply_fn: callable receiving one float tensor of input frames (values in
      [0, 255], shape [batch_size, n_input_frames, height, width, depth]) and
      returning one float tensor of predicted frames (values in [0, 255],
      shape [batch_size, n_output_frames, height, width, depth]).
    load_fn: callable receiving a session; it should restore the model's
      variables from a checkpoint.
  """
def evaluate_model(video_eval_dataset, model, num_batches, batch_size):
  """Compute the FVD video metric (placeholder, not implemented yet).

  Args:
    video_eval_dataset: VideoEvaluationDataset tuple with video and frames
      information.
    model: Model tuple with the model to evaluate.
    num_batches: number of batches to evaluate.
    batch_size: number of videos to compute per batch.

  Returns:
    Intended to return the FVD metric (float). This stub currently ignores
    all of its arguments and returns None.
  """
  del video_eval_dataset, model, num_batches, batch_size  # Unused by the stub.
  return None
================================================
FILE: tensor2tensor/metrics/video_conditional_fvd_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for video_conditional_fvd."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.metrics import video_conditional_fvd
import tensorflow.compat.v1 as tf
class VideoConditionalFvdTest(tf.test.TestCase):
  """Smoke test: evaluate_model accepts well-formed dataset/model tuples."""

  def test_sample(self):
    eval_dataset = video_conditional_fvd.VideoEvaluationDataset(
        n_input_frames=4, n_output_frames=10, get_video_batch_fn=None)
    eval_model = video_conditional_fvd.Model(apply_fn=None, load_fn=None)
    num_batches, batch_size = 10, 16
    video_conditional_fvd.evaluate_model(
        eval_dataset, eval_model, num_batches, batch_size)
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/models/README.md
================================================
# Constructing T2T Models.
This directory contains T2T models, their hyperparameters, and a number
of common layers and hyperparameter settings to help construct new models.
Common building blocks are in `common_layers.py` and `common_attention.py`.
Common hyperparameters are in `common_hparams.py`. Models are imported in
`__init__.py`.
## Adding a new model.
To add a model to the built-in set, create a new file (see, e.g.,
`neural_gpu.py`), define your model class there as a subclass of `T2TModel`,
and decorate it with `@registry.register_model`. Then import the new module in
`__init__.py`. The model is now available to the trainer binary
(`t2t-trainer`) via the `--model=model_name` flag.
================================================
FILE: tensor2tensor/models/__init__.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Models defined in T2T. Imports here force registration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
# pylint: disable=unused-import
from tensor2tensor.layers import modalities # pylint: disable=g-import-not-at-top
from tensor2tensor.models import basic
from tensor2tensor.models import bytenet
from tensor2tensor.models import distillation
from tensor2tensor.models import evolved_transformer
from tensor2tensor.models import image_transformer
from tensor2tensor.models import image_transformer_2d
from tensor2tensor.models import lstm
from tensor2tensor.models import neural_assistant
from tensor2tensor.models import neural_gpu
from tensor2tensor.models import resnet
from tensor2tensor.models import revnet
from tensor2tensor.models import shake_shake
from tensor2tensor.models import slicenet
from tensor2tensor.models import text_cnn
from tensor2tensor.models import transformer
from tensor2tensor.models import vanilla_gan
from tensor2tensor.models import xception
from tensor2tensor.models.neural_architecture_search import nas_model
from tensor2tensor.models.research import adafactor_experiments
from tensor2tensor.models.research import aligned
from tensor2tensor.models.research import autoencoders
from tensor2tensor.models.research import cycle_gan
from tensor2tensor.models.research import gene_expression
from tensor2tensor.models.research import neural_stack
from tensor2tensor.models.research import residual_shuffle_exchange
from tensor2tensor.models.research import rl
from tensor2tensor.models.research import shuffle_network
from tensor2tensor.models.research import similarity_transformer
from tensor2tensor.models.research import super_lm
from tensor2tensor.models.research import transformer_moe
from tensor2tensor.models.research import transformer_nat
from tensor2tensor.models.research import transformer_parallel
from tensor2tensor.models.research import transformer_revnet
from tensor2tensor.models.research import transformer_seq2edits
from tensor2tensor.models.research import transformer_sketch
from tensor2tensor.models.research import transformer_symshard
from tensor2tensor.models.research import transformer_vae
from tensor2tensor.models.research import universal_transformer
from tensor2tensor.models.video import basic_deterministic
from tensor2tensor.models.video import basic_recurrent
from tensor2tensor.models.video import basic_stochastic
from tensor2tensor.models.video import emily
from tensor2tensor.models.video import savp
from tensor2tensor.models.video import sv2p
from tensor2tensor.utils import contrib
from tensor2tensor.utils import registry
# The following models can't be imported under TF2
if not contrib.is_tf2:
# pylint: disable=g-import-not-at-top
from tensor2tensor.models.research import attention_lm
from tensor2tensor.models.research import attention_lm_moe
from tensor2tensor.models.research import glow
from tensor2tensor.models.research import lm_experiments
from tensor2tensor.models.research import moe_experiments
from tensor2tensor.models.research import multiquery_paper
from tensor2tensor.models import mtf_image_transformer
from tensor2tensor.models import mtf_resnet
from tensor2tensor.models import mtf_transformer
from tensor2tensor.models import mtf_transformer2
from tensor2tensor.models.research import vqa_attention
from tensor2tensor.models.research import vqa_recurrent_self_attention
from tensor2tensor.models.research import vqa_self_attention
from tensor2tensor.models.video import epva
from tensor2tensor.models.video import next_frame_glow
# pylint: enable=g-import-not-at-top
# pylint: disable=unused-import
# pylint: enable=unused-import
def model(name):
  """Look up a model by its registered name via `registry.model`.

  Importing this package first runs the imports above, which register the
  built-in models.
  """
  registered = registry.model(name)
  return registered
================================================
FILE: tensor2tensor/models/basic.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic models for testing simple tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
@registry.register_model
class BasicFcRelu(t2t_model.T2TModel):
  """Basic fully-connected + ReLU model."""

  def body(self, features):
    """Flatten the inputs and apply a stack of dense/dropout/ReLU layers.

    Args:
      features: dict whose "inputs" entry has at least 4 dimensions; dims
        1-3 are flattened together (assumes rank 4 -- TODO confirm).

    Returns:
      Tensor of shape [batch, 1, 1, hidden_size] (T2T bodies are 4-D).
    """
    hp = self.hparams
    inputs = features["inputs"]
    dims = common_layers.shape_list(inputs)
    flat_size = dims[1] * dims[2] * dims[3]
    net = tf.reshape(inputs, [-1, flat_size])
    keep_prob = 1.0 - hp.dropout
    for layer_idx in range(hp.num_hidden_layers):
      net = tf.layers.dense(net, hp.hidden_size, name="layer_%d" % layer_idx)
      net = tf.nn.dropout(net, keep_prob=keep_prob)
      net = tf.nn.relu(net)
    net = tf.expand_dims(net, axis=1)
    return tf.expand_dims(net, axis=1)
@registry.register_hparams
def basic_fc_small():
  """Small fully connected model: basic_params1 with 2 x 256 hidden units."""
  hparams = common_hparams.basic_params1()
  overrides = (
      ("learning_rate", 0.1),
      ("batch_size", 128),
      ("hidden_size", 256),
      ("num_hidden_layers", 2),
      ("initializer", "uniform_unit_scaling"),
      ("initializer_gain", 1.0),
      ("weight_decay", 0.0),
      ("dropout", 0.0),
  )
  for key, value in overrides:
    setattr(hparams, key, value)
  return hparams
================================================
FILE: tensor2tensor/models/basic_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic nets tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.data_generators import mnist # pylint: disable=unused-import
from tensor2tensor.models import basic
from tensor2tensor.utils import trainer_lib
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class BasicTest(tf.test.TestCase):
  """Tests for the basic fully connected model."""

  def testBasicFcRelu(self):
    inputs_np = np.random.randint(256, size=(1, 28, 28, 1))
    targets_np = np.random.randint(10, size=(1, 1))
    hparams = trainer_lib.create_hparams(
        "basic_fc_small", problem_name="image_mnist", data_dir=".")
    with self.test_session() as sess:
      features = {
          "inputs": tf.constant(inputs_np, dtype=tf.int32),
          "targets": tf.constant(targets_np, dtype=tf.int32),
      }
      fc_model = basic.BasicFcRelu(hparams, tf_estimator.ModeKeys.TRAIN)
      logits, _ = fc_model(features)
      sess.run(tf.global_variables_initializer())
      output = sess.run(logits)
    # One example, 10 MNIST classes.
    self.assertEqual(output.shape, (1, 1, 1, 1, 10))
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/models/bytenet.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ByteNet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
def residual_dilated_conv(x, repeat, padding, name, hparams):
  """Applies `repeat` residual blocks of dilated convolutions.

  Each block layer-normalizes the running activations, runs
  `common_layers.conv_block` with dilation rates 1, 2, 4, ... (one per
  `hparams.num_hidden_layers`) along axis 1, applies dropout, and adds the
  result back onto the running activations.

  Args:
    x: input Tensor (presumably [batch, length, 1, hidden_size] — TODO
      confirm against callers).
    repeat: number of residual blocks to stack.
    padding: padding mode forwarded to `conv_block` ("SAME" or "LEFT" in
      this file).
    name: variable scope wrapping the whole stack.
    hparams: reads kernel_height, kernel_width, num_hidden_layers,
      hidden_size and dropout.

  Returns:
    The activations after all residual additions.
  """
  kernel = (hparams.kernel_height, hparams.kernel_width)
  dilation_schedule = [((2**power, 1), kernel)
                       for power in range(hparams.num_hidden_layers)]
  keep_prob = 1.0 - hparams.dropout
  with tf.variable_scope(name):
    for block_idx in range(repeat):
      with tf.variable_scope("repeat_%d" % block_idx):
        normed = common_layers.layer_norm(
            x, hparams.hidden_size, name="lnorm")
        residual = common_layers.conv_block(
            normed,
            hparams.hidden_size,
            dilation_schedule,
            padding=padding,
            name="residual_conv")
        x += tf.nn.dropout(residual, keep_prob)
  return x
def bytenet_internal(inputs, targets, hparams):
  """Core ByteNet step: dilated-conv encoder followed by a causal decoder.

  Args:
    inputs: Tensor accepted by `common_layers.flatten4d3d` (presumably
      [batch, length, height, depth] — TODO confirm).
    targets: target Tensor, padded to the same length as the inputs.
    hparams: reads num_block_repeat, kernel_height, kernel_width,
      hidden_size, plus what `residual_dilated_conv` reads.

  Returns:
    Output of the "LEFT"-padded decoder residual stack.
  """
  with tf.variable_scope("bytenet"):
    # Flatten to 3-D, re-insert a unit axis 2, then right-pad the length
    # axis by 50% of its dynamic size.
    flat_inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    extra = tf.to_int32(0.5 * tf.to_float(tf.shape(flat_inputs)[1]))
    static_shape = flat_inputs.shape.as_list()
    static_shape[1] = None
    padded_inputs = tf.pad(flat_inputs,
                           [[0, 0], [0, extra], [0, 0], [0, 0]])
    # Keep the other static dimensions that the dynamic pad would lose.
    padded_inputs.set_shape(static_shape)
    # Make inputs and targets equally long, with a length divisible by 50.
    padded_inputs, targets = common_layers.pad_to_same_length(
        padded_inputs, targets, final_length_divisible_by=50)
    encoded = residual_dilated_conv(padded_inputs, hparams.num_block_repeat,
                                    "SAME", "encoder", hparams)
    # The decoder sees encoder states concatenated with right-shifted targets.
    decoder_in = tf.concat(
        [encoded, common_layers.shift_right(targets)], axis=3)
    decoder_start = common_layers.conv_block(
        decoder_in,
        hparams.hidden_size,
        [((1, 1), (hparams.kernel_height, hparams.kernel_width))],
        padding="LEFT")
    return residual_dilated_conv(decoder_start, hparams.num_block_repeat,
                                 "LEFT", "decoder", hparams)
@registry.register_model
class ByteNet(t2t_model.T2TModel):
  """ByteNet model; the body delegates to `bytenet_internal`."""

  def body(self, features):
    """Runs ByteNet on `features["inputs"]` and `features["targets"]`."""
    source = features["inputs"]
    target = features["targets"]
    return bytenet_internal(source, target, self._hparams)
@registry.register_hparams
def bytenet_base():
  """Base ByteNet hyperparameters layered on `common_hparams.basic_params1`."""
  hparams = common_hparams.basic_params1()
  overrides = [
      # Model size, batching and regularization.
      ("batch_size", 2048),
      ("hidden_size", 768),
      ("dropout", 0.2),
      ("symbol_dropout", 0.2),
      ("label_smoothing", 0.1),
      ("clip_grad_norm", 2.0),
      # Conv stack: number of dilation rates per block and kernel size.
      ("num_hidden_layers", 4),
      ("kernel_height", 3),
      ("kernel_width", 1),
      # Learning-rate schedule, initialization and optimizer.
      ("learning_rate_decay_scheme", "exp"),
      ("learning_rate", 0.05),
      ("learning_rate_warmup_steps", 3000),
      ("initializer_gain", 1.0),
      ("weight_decay", 3.0),
      ("num_sampled_classes", 0),
      ("sampling_method", "argmax"),
      ("optimizer_adam_epsilon", 1e-6),
      ("optimizer_adam_beta1", 0.85),
      ("optimizer_adam_beta2", 0.997),
  ]
  for hp_name, hp_value in overrides:
    setattr(hparams, hp_name, hp_value)
  # Number of residual dilated-conv blocks in both encoder and decoder.
  hparams.add_hparam("num_block_repeat", 4)
  return hparams
================================================
FILE: tensor2tensor/models/bytenet_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ByteNet tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import bytenet
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class ByteNetTest(tf.test.TestCase):
  """Smoke test for the ByteNet model."""

  def testByteNet(self):
    """Builds ByteNet in TRAIN mode and checks the padded logits shape."""
    vocab_size = 9
    src = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
    tgt = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
    hparams = bytenet.bytenet_base()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    # Lengths are padded to a multiple of 50 inside bytenet_internal.
    expected_shape = (3, 50, 1, 1, vocab_size)
    with self.test_session() as sess:
      src_t = tf.constant(src, dtype=tf.int32)
      tgt_t = tf.constant(tgt, dtype=tf.int32)
      model = bytenet.ByteNet(
          hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model({"inputs": src_t, "targets": tgt_t})
      sess.run(tf.global_variables_initializer())
      result = sess.run(logits)
    self.assertEqual(result.shape, expected_shape)
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/models/distillation.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Traditional Student-Teacher Distillation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
@registry.register_model
class Distillation(t2t_model.T2TModel):
  """Distillation from a teacher to a student network.

  First, a teacher is trained on a task (distill_phase == "train"). Second, a
  student is trained (distill_phase == "distill") to perform the task while
  matching the teacher's temperature-softened outputs. For more details, see
  the paper below.

  The hparams passed to this model must include the desired
  {teacher/student}_model and {teacher/student}_hparams to be used. Also
  specify the distillation temperature and the task-distillation balance.

  Distilling the Knowledge in a Neural Network
  Hinton, Vinyals and Dean
  https://arxiv.org/abs/1503.02531
  """

  def __init__(self,
               hparams,
               mode=tf_estimator.ModeKeys.TRAIN,
               problem_hparams=None,
               data_parallelism=None,
               decode_hparams=None,
               **kwargs):
    """Builds the teacher and student sub-models.

    Args:
      hparams: model hparams. Must define distill_phase ("train" or
        "distill") and the registry names teacher_model/teacher_hparams and
        student_model/student_hparams. When the phase-specific
        teacher_learning_rate / student_learning_rate is truthy it
        overwrites hparams.learning_rate.
      mode: a tf_estimator.ModeKeys value.
      problem_hparams: forwarded to teacher, student and T2TModel.
      data_parallelism: forwarded to teacher, student and T2TModel.
      decode_hparams: forwarded to teacher, student and T2TModel.
      **kwargs: forwarded to T2TModel.__init__ only.
    """
    assert hparams.distill_phase in ["train", "distill"]
    if hparams.distill_phase == "train" and hparams.teacher_learning_rate:
      hparams.learning_rate = hparams.teacher_learning_rate
    elif hparams.distill_phase == "distill" and hparams.student_learning_rate:
      hparams.learning_rate = hparams.student_learning_rate

    self.teacher_hparams = registry.hparams(hparams.teacher_hparams)
    self.teacher_model = registry.model(
        hparams.teacher_model)(self.teacher_hparams, mode, problem_hparams,
                               data_parallelism, decode_hparams)
    self.student_hparams = registry.hparams(hparams.student_hparams)
    self.student_model = registry.model(
        hparams.student_model)(self.student_hparams, mode, problem_hparams,
                               data_parallelism, decode_hparams)
    super(Distillation,
          self).__init__(hparams, mode, problem_hparams, data_parallelism,
                         decode_hparams, **kwargs)

  def body(self, features):
    """Computes logits and the training loss for the current phase.

    Args:
      features: dict containing "targets_raw" (squeezed on axes 1, 2 and 3, so
        presumably [batch, 1, 1, 1] class ids — TODO confirm), plus whatever
        the teacher/student bodies consume.

    Returns:
      (outputs, losses): `outputs` are the active network's logits reshaped
      to [batch, 1, 1, 1, num_classes]; `losses` is {"training": loss}, where
      loss is a per-example cross-entropy vector.
    """
    hp = self.hparams
    is_distill = hp.distill_phase == "distill"

    targets = features["targets_raw"]
    targets = tf.squeeze(targets, [1, 2, 3])
    one_hot_targets = tf.one_hot(targets, hp.num_classes, dtype=tf.float32)

    # Teacher Network
    with tf.variable_scope("teacher"):
      teacher_outputs = self.teacher_model.body(features)
      tf.logging.info("teacher output shape: %s" % teacher_outputs.get_shape())
      # Pool over axes 1 and 2 before the classification layer.
      teacher_outputs = tf.reduce_mean(teacher_outputs, axis=[1, 2])
      teacher_logits = tf.layers.dense(teacher_outputs, hp.num_classes)

      teacher_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=one_hot_targets, logits=teacher_logits)
      outputs = teacher_logits

    if is_distill:
      # Load teacher weights (hp.teacher_dir is presumably set by the
      # distillation driver; it is not added in distill_base).
      tf.train.init_from_checkpoint(hp.teacher_dir, {"teacher/": "teacher/"})
      # Do not train the teacher: empty the trainable-variables collection.
      # Student variables are created below and so remain trainable.
      trainable_vars = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
      del trainable_vars[:]

    # Student Network
    if is_distill:
      with tf.variable_scope("student"):
        student_outputs = self.student_model.body(features)
        tf.logging.info(
            "student output shape: %s" % student_outputs.get_shape())
        student_outputs = tf.reduce_mean(student_outputs, axis=[1, 2])
        student_logits = tf.layers.dense(student_outputs, hp.num_classes)

        student_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=student_logits)
        teacher_targets = tf.nn.softmax(teacher_logits / hp.distill_temperature)
        student_distill_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.stop_gradient(teacher_targets),
            logits=student_logits / hp.distill_temperature)
        # scale soft target obj. to match hard target obj. scale
        student_distill_xent *= hp.distill_temperature**2

        outputs = student_logits

        # Summaries. student_distill_xent is one value per example, while
        # a scalar summary needs a single value, so log the batch mean.
        tf.summary.scalar("distill_xent", tf.reduce_mean(student_distill_xent))

    if not is_distill:
      phase_loss = teacher_task_xent
    else:
      # Convex mix of the hard-label loss and the soft-target loss.
      phase_loss = hp.task_balance * student_task_xent
      phase_loss += (1 - hp.task_balance) * student_distill_xent

    losses = {"training": phase_loss}
    outputs = tf.reshape(outputs, [-1, 1, 1, 1, outputs.shape[1]])

    return outputs, losses

  def top(self, body_output, features):
    """Returns `body_output` unchanged; `body` already produces logits."""
    return body_output
def distill_base():
  """Base hyperparameters shared by the distillation configurations.

  Starts from `common_hparams.basic_params1`, registers the teacher/student
  and distillation-specific hparams, then applies ResNet-style training
  defaults.
  """
  hparams = common_hparams.basic_params1()

  new_hparams = [
      # Registry names of the teacher/student models and their hparams sets.
      ("teacher_model", ""),
      ("teacher_hparams", ""),
      ("student_model", ""),
      ("student_hparams", ""),
      # WARNING: distill_phase hparam will be overwritten in /bin/t2t_distill.py
      ("distill_phase", None),
      ("task_balance", 1.0),
      ("distill_temperature", 1.0),
      ("num_classes", 10),
      # Optional phase-specific learning rates; None keeps learning_rate.
      ("teacher_learning_rate", None),
      ("student_learning_rate", None),
  ]
  for hp_name, default in new_hparams:
    hparams.add_hparam(hp_name, default)

  # Training parameters (taken from ResNet).
  hparams.batch_size = 128
  hparams.optimizer = "Momentum"
  hparams.optimizer_momentum_momentum = 0.9
  hparams.optimizer_momentum_nesterov = True
  hparams.weight_decay = 1e-4
  hparams.clip_grad_norm = 0.0
  # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.)
  hparams.learning_rate = 0.4
  hparams.learning_rate_decay_scheme = "cosine"
  # For image_imagenet224, 120k training steps makes the cosine schedule a
  # single decay (no cycles).
  hparams.learning_rate_cosine_cycle_steps = 120000
  hparams.initializer = "normal_unit_scaling"
  hparams.initializer_gain = 2.
  return hparams
@registry.register_hparams
def distill_resnet_32_to_15_cifar20x5():
  """Distill a resnet_cifar_32 teacher into a resnet_cifar_15 student.

  Uses 20 output classes, a piecewise learning-rate schedule, a distillation
  temperature of 2 and task_balance 0.28.
  """
  hparams = distill_base()
  hparams.teacher_model, hparams.teacher_hparams = "resnet", "resnet_cifar_32"
  hparams.student_model, hparams.student_hparams = "resnet", "resnet_cifar_15"
  hparams.optimizer_momentum_nesterov = True
  # Linear LR scaling: base_lr * (128 per replica * 8 replicas = 1024) / 256.
  lr_scale = 128. * 8. / 256.
  hparams.teacher_learning_rate = 0.25 * lr_scale
  hparams.student_learning_rate = 0.2 * lr_scale
  hparams.learning_rate_decay_scheme = "piecewise"
  hparams.add_hparam("learning_rate_boundaries", [40000, 60000, 80000])
  hparams.add_hparam("learning_rate_multiples", [0.1, 0.01, 0.001])
  hparams.task_balance = 0.28
  hparams.distill_temperature = 2.0
  hparams.num_classes = 20
  return hparams
================================================
FILE: tensor2tensor/models/evolved_transformer.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evolved Transformer model.
This implements the model described in arxiv.org/abs/1901.11117 .
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import inplace_ops
# pylint: enable=g-direct-tensorflow-import
# Variable-scope names (also used as fast-decoding cache keys) for the
# decoder sub-layers.
_CONV_BRANCHES_NAME = "conv_branches"
_CONV_BRANCHES_FIRST_LAYER_NAME = "%s_first" % _CONV_BRANCHES_NAME
_CONV_BRANCHES_SECOND_LAYER_NAME = "%s_second" % _CONV_BRANCHES_NAME
_FIRST_ATTEND_TO_ENCODER_NAME = "first_attend_to_encoder"
_SECOND_ATTEND_TO_ENCODER_NAME = "second_attend_to_encoder"
_SIXTEEN_HEAD_ATTENTION_NAME = "16_head_self_attention"
_VANILLA_ATTENTION_NAME = "self_attention"

# Causal left padding before the decoder's VALID separable convolutions; a
# width-k kernel needs k - 1 earlier positions to keep the length unchanged.
_DECODER_LEFT_CONV_PADDING = 10  # left branch, 11x1 conv
_DECODER_RIGHT_CONV_PADDING = 6  # right branch, first 7x1 conv
_DECODER_FINAL_CONV_PADDING = 6  # final 7x1 conv
def _capped_double_heads(num_heads, cap=16):
  """Head count for the attention layers that use extra heads.

  Doubles `num_heads` but clips the result to `cap`. The result never drops
  below `num_heads`, so a model that already has more than `cap` heads keeps
  its own count.

  Args:
    num_heads: the num_heads hparam for the model.
    cap: upper bound on the doubled value.

  Returns:
    max(min(2 * num_heads, cap), num_heads).
  """
  doubled = num_heads * 2
  capped = doubled if doubled < cap else cap
  return capped if capped > num_heads else num_heads
@registry.register_model
class EvolvedTransformer(transformer.Transformer):
  """The Evolved Transformer from arxiv.org/abs/1901.11117 ."""

  def __init__(self, *args, **kwargs):
    """Swaps in the Evolved Transformer encoder, decoder and cache init.

    If hparams.num_trainable_top_decoder_layers (default -1) is >= 0 and
    shared_embedding_and_softmax_weights is enabled, softmax/embedding
    sharing is turned off and shared_embedding is turned on instead.
    """
    super(EvolvedTransformer, self).__init__(*args, **kwargs)
    self._encoder_function = evolved_transformer_encoder
    self._decoder_function = evolved_transformer_decoder
    self._init_cache_fn = init_evolved_transformer_cache

    # -1 means train all weights.
    if self.hparams.get("num_trainable_top_decoder_layers", -1) < 0:
      t2t_model.log_info(
          "num_trainable_top_decoder_layers is negative so training all weights."
      )
    elif self.hparams.shared_embedding_and_softmax_weights:
      t2t_model.log_info(
          "Setting hparams.shared_embedding_and_softmax_weights to False, "
          "because hparam.num_trainable_top_decoder_layers is being used.")

      # When hparam.num_trainable_top_decoder_layers is set to N >= 0 we will
      # freeze (not train) every variable except the N top decoder layers and
      # the (pre-)softmax matrix. For any N >= 0 we will freeze the encoder and
      # input/target embeddings. This also means we will not share the
      # (pre-)softmax matrix with input/target embeddings otherwise they will be
      # trained as well.
      self.hparams.shared_embedding_and_softmax_weights = False

      # If hparams.shared_embedding_and_softmax_weights was previously True,
      # then input and target embeddings were being shared.
      # To make sure the embeddings continue to be shared, we need to set
      # hparams.shared_embedding to True.
      self.hparams.shared_embedding = True
def evolved_transformer_encoder(encoder_input,
                                encoder_self_attention_bias,
                                hparams,
                                name="encoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None,
                                attn_bias_for_padding=None):
  """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.

  Each layer applies, in order: a gated linear unit, a pair of convolution
  branches, optional multi-head self-attention (hparams
  `et_encoder_self_attention`, default True) and a two-layer dense block.
  Each is wrapped in layer_preprocess / layer_postprocess.

  Note: Pad remover is not supported.

  Args:
    encoder_input: a Tensor.
    encoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    hparams: hyperparameters for model.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding. This must either be passed in,
      which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias. Here it is only used to mask out padding
      in the convolutional layers.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not used.
    attn_bias_for_padding: Padded attention bias in case a unidirectional
      encoder is being used where future attention is masked.

  Returns:
    Tensor encoder output.
  """
  del losses

  hidden_state = encoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name):
    if nonpadding is not None:
      # NOTE(review): `padding` is not used after this point; only
      # `nonpadding` feeds the conv masks below.
      padding = 1.0 - nonpadding
    else:
      # Recover padding from the attention bias. attn_bias_for_padding takes
      # precedence when given (e.g. when the self-attention bias also masks
      # future positions).
      attention_bias = encoder_self_attention_bias
      if attn_bias_for_padding is not None:
        attention_bias = attn_bias_for_padding
      # Only bfloat16 and float32 supported.
      float_type = hparams.get("activation_dtype", "float32")
      if float_type == "bfloat16":
        cast_fn = tf.to_bfloat16
      else:
        assert float_type == "float32"
        cast_fn = tf.to_float
      padding = common_attention.attention_bias_to_padding(
          attention_bias, cast_fn)
      nonpadding = 1.0 - padding

    for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        # Gated linear unit: linear values multiplied by sigmoid gates.
        with tf.variable_scope("gated_linear_unit"):

          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          values = common_layers.layers().Dense(
              hparams.hidden_size)(hidden_state)
          gates = common_layers.layers().Dense(
              hparams.hidden_size, activation=tf.nn.sigmoid)(hidden_state)
          hidden_state = values * gates

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        # Two parallel branches (4x-wide ReLU dense, and a 3x1 ReLU conv to
        # hidden_size / 2 zero-padded on channels to 4x), summed and then
        # fed to a 9x1 separable conv that is zero-padded back to hidden_size.
        with tf.variable_scope("conv_branches"):

          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          # Mask padding from conv layers.
          mask = tf.tile(
              tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
          hidden_state *= mask

          left_output_dim = int(hparams.hidden_size * 4)
          left_state = common_layers.layers().Dense(
              left_output_dim, activation=tf.nn.relu)(hidden_state)
          left_state = tf.nn.dropout(left_state,
                                     1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          right_state = common_layers.layers().Conv1D(
              right_output_dim,
              3,
              padding="SAME",
              name="standard_conv_3x1",
              activation=tf.nn.relu)(hidden_state)
          right_state = tf.nn.dropout(right_state,
                                      1 - hparams.layer_prepostprocess_dropout)

          # Zero-pad the channel axis so the two branches can be summed.
          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)
          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          # Mask padding from conv layer.
          mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim])
          hidden_state *= mask

          separable_conv_9x1 = common_layers.layers().SeparableConv1D(
              right_output_dim, 9, padding="SAME", name="separable_conv_9x1")
          hidden_state = separable_conv_9x1(hidden_state)
          # Zero-pad channels back to hidden_size for the residual add.
          hidden_state = tf.pad(
              hidden_state,
              [[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]],
              constant_values=0)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        if hparams.get("et_encoder_self_attention", True):
          with tf.variable_scope("self_attention"):
            residual_state = hidden_state
            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

            hidden_state = common_attention.multihead_attention(
                hidden_state,
                None,
                encoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=hparams.self_attention_type,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))

            hidden_state = common_layers.layer_postprocess(
                residual_state, hidden_state, hparams)

        # Position-wise feed-forward: 4x ReLU expansion, then projection back.
        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = common_layers.layers().Dense(
              int(hparams.hidden_size * 4), activation=tf.nn.relu)(hidden_state)
          hidden_state = tf.nn.dropout(hidden_state,
                                       1 - hparams.layer_prepostprocess_dropout)

          hidden_state = common_layers.layers().Dense(
              hparams.hidden_size)(hidden_state)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    # If normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(hidden_state, hparams)
def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
  """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.

  Each layer applies, in order:
  - wide self-attention (doubled head count), summed with a first
    encoder-decoder attention when `encoder_output` is given;
  - causal convolution branches (11x1 and 7x1 separable convs, then a final
    7x1 separable conv);
  - vanilla self-attention;
  - a second encoder-decoder attention, when `encoder_output` is given;
  - a swish feed-forward block.

  If hparams.num_trainable_top_decoder_layers is N >= 0, gradients are
  stopped into the encoder output and below the top N decoder layers.

  Args:
    decoder_input: a Tensor.
    encoder_output: a Tensor.
    decoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for model.
    cache: dict, containing tensors which are the results of previous
      layers, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string.
    nonpadding: optional Tensor indicating what positions are not padding.
      It is multiplied into the decoder hidden states after expand_dims(.., 2),
      so it must match the decoder length (i.e. [batch_size, decoder_length]).
      This is used to mask out padding in convolutional layers. We generally
      only need this mask for "packed" datasets, because for ordinary
      datasets, no padding is ever followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not supported.

  Returns:
    Decoder output tensor.
  """
  del losses

  num_trainable_top_decoder_layers = hparams.get(
      "num_trainable_top_decoder_layers", -1)  # -1 means train all weights.

  if num_trainable_top_decoder_layers >= 0:
    # Freeze the encoder: no gradients flow back through its output.
    encoder_output = tf.stop_gradient(encoder_output)

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name):
    hidden_state = decoder_input

    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    for layer in range(num_layers):
      # Cut gradients at the input of the lowest trainable layer so only the
      # top `num_trainable_top_decoder_layers` layers are trained.
      if num_trainable_top_decoder_layers == num_layers - layer:
        hidden_state = tf.stop_gradient(hidden_state)
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):

        with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None
          left_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              _capped_double_heads(hparams.num_heads),
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))

        if encoder_output is not None:
          # Parallel branch: attend to the encoder, then sum both branches
          # with the residual (dropout applied manually instead of
          # layer_postprocess).
          with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME):
            attention_cache = (
                layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            right_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))

            left_state = tf.nn.dropout(left_state,
                                       1 - hparams.layer_prepostprocess_dropout)
            right_state = tf.nn.dropout(
                right_state, 1 - hparams.layer_prepostprocess_dropout)

            hidden_state = residual_state + left_state + right_state

        else:
          hidden_state = common_layers.layer_postprocess(
              residual_state, left_state, hparams)

        with tf.variable_scope(_CONV_BRANCHES_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              # Append the new step to the cached history and keep only the
              # last _DECODER_LEFT_CONV_PADDING + 1 positions (the 11x1 conv
              # window); the right branch uses the trailing 7 of those.
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :]
              left_state = hidden_state
              right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING -
                                         _DECODER_RIGHT_CONV_PADDING:, :]

            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              # The cache is transposed so the time axis comes first, the new
              # step is written at decode_loop_step * len + left padding, and
              # conv windows are then sliced starting at decode_loop_step.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp,
                  decode_loop_step * tf.shape(hidden_state)[1] +
                  _DECODER_LEFT_CONV_PADDING,
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              # NOTE(review): requires a static batch dimension; as_list()[0]
              # would be None otherwise.
              batch_size = hidden_state.shape.as_list()[0]
              left_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_LEFT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])
              right_state = tf.slice(hidden_state, [
                  0, decode_loop_step + _DECODER_LEFT_CONV_PADDING -
                  _DECODER_RIGHT_CONV_PADDING, 0
              ], [
                  batch_size, _DECODER_RIGHT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])

          else:  # No caching.
            # Causal (left-only) padding so the VALID convs below keep length.
            left_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]])
            right_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]])

          left_output_dim = int(hparams.hidden_size * 2)
          separable_conv_11x1 = tf.layers.SeparableConv1D(
              left_output_dim,
              11,
              padding="VALID",
              name="separable_conv11x1",
              activation=tf.nn.relu)
          left_state = separable_conv_11x1.apply(left_state)
          left_state = tf.nn.dropout(left_state,
                                     1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          separable_conv_7x1_1 = tf.layers.SeparableConv1D(
              right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1")
          right_state = separable_conv_7x1_1.apply(right_state)
          right_state = tf.nn.dropout(right_state,
                                      1 - hparams.layer_prepostprocess_dropout)
          # Zero-pad channels so the right branch matches the left's width.
          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)

          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              # Same sliding-window cache as above, sized for the final 7x1
              # conv (_DECODER_FINAL_CONV_PADDING + 1 positions).
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :]

            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              # NOTE(review): this index is (step + pad) * len, while the first
              # conv cache uses step * len + pad. They agree only when
              # tf.shape(hidden_state)[1] == 1 — confirm that always holds here.
              tmp = inplace_ops.alias_inplace_update(
                  tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) *
                  tf.shape(hidden_state)[1],
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              batch_size = hidden_state.shape.as_list()[0]
              hidden_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_FINAL_CONV_PADDING + 1,
                  hparams.hidden_size * 2
              ])
          else:
            hidden_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]])

          separable_conv_7x1_2 = tf.layers.SeparableConv1D(
              hparams.hidden_size,
              7,
              padding="VALID",
              name="separable_conv_7x1_2")
          hidden_state = separable_conv_7x1_2.apply(hidden_state)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        with tf.variable_scope(_VANILLA_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _VANILLA_ATTENTION_NAME] if layer_cache is not None else None
          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        if encoder_output is not None:
          with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME):
            residual_state = hidden_state
            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

            attention_cache = (
                layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            hidden_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))
            hidden_state = common_layers.layer_postprocess(
                residual_state, hidden_state, hparams)

        # Feed-forward: 4x swish expansion, normalization, projection back.
        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(
              hidden_state,
              int(hparams.hidden_size * 4),
              activation=tf.nn.swish)
          hidden_state = tf.nn.dropout(hidden_state,
                                       1 - hparams.layer_prepostprocess_dropout)

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    decoder_output = common_layers.layer_preprocess(hidden_state, hparams)
    # With zero trainable decoder layers, nothing below the output is trained.
    if num_trainable_top_decoder_layers == 0:
      decoder_output = tf.stop_gradient(decoder_output)
    return decoder_output
def _add_attend_to_encoder_cache(cache, attention_name, hparams, num_layers,
                                 key_channels, value_channels,
                                 vars_3d_num_heads, scope_prefix,
                                 encoder_output):
  """Add attend-to-encoder keys/values for every decoder layer to the cache.

  For each decoder layer "layer_<i>", `encoder_output` is projected to keys
  and values under the variable scope
  "<scope_prefix>decoder/layer_<i>/<attention_name>/multihead_attention"
  (presumably the scope the decoder's attention uses, so the weights are
  reused — confirm against the decoder), split into `hparams.num_heads` heads
  and stored as `cache["layer_<i>"][attention_name] = {"k_encdec": ...,
  "v_encdec": ...}`.

  Args:
    cache: dict whose "layer_<i>" entries already exist; mutated in place.
    attention_name: name of the attention sublayer; also a scope component.
    hparams: model hyperparameters; only `num_heads` is read.
    num_layers: number of decoder layers to fill in.
    key_channels: depth of the key projection.
    value_channels: depth of the value projection.
    vars_3d_num_heads: forwarded to `compute_attention_component`.
    scope_prefix: prefix prepended to the "decoder/..." variable scope.
    encoder_output: tensor that is projected to keys and values.

  Returns:
    The same `cache` object.
  """

  def _project_and_split(depth, component_name):
    # Project the encoder output, then split the result into attention heads.
    projected = common_attention.compute_attention_component(
        encoder_output,
        depth,
        name=component_name,
        vars_3d_num_heads=vars_3d_num_heads)
    return common_attention.split_heads(projected, hparams.num_heads)

  for layer_index in range(num_layers):
    layer_key = "layer_%d" % layer_index
    scope_name = "%sdecoder/%s/%s/multihead_attention" % (
        scope_prefix, layer_key, attention_name)
    with tf.variable_scope(scope_name):
      keys = _project_and_split(key_channels, "k")
      values = _project_and_split(value_channels, "v")
    cache[layer_key][attention_name] = {"k_encdec": keys, "v_encdec": values}
  return cache
def init_evolved_transformer_cache(cache, hparams, batch_size,
                                   attention_init_length, encoder_output,
                                   encoder_decoder_attention_bias,
                                   scope_prefix):
  """Create the initial cache for Evolved Transformer fast decoding.

  Args:
    cache: dict to extend, or None to start from a new empty dict.
    hparams: model hyperparameters.
    batch_size: leading dimension of every zero-initialized cache tensor.
    attention_init_length: initial time length of the self-attention caches;
      the convolution caches add their padding on top of it.
    encoder_output: encoder output, or None when there is no encoder.
    encoder_decoder_attention_bias: stored next to `encoder_output` (only
      when `encoder_output` is not None).
    scope_prefix: variable-scope prefix for the encoder key/value projections.

  Returns:
    The cache dict, with one "layer_<i>" entry per decoder layer holding
    zero-filled 16-head and vanilla self-attention keys/values and
    zero-filled convolution-branch buffers. When `encoder_output` is given,
    both attend-to-encoder sublayers also get precomputed keys/values, and
    the encoder output and bias are stored at the top level of the cache.
  """
  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  if hparams.get("attention_variables_3d"):
    vars_3d_num_heads = hparams.num_heads
  else:
    vars_3d_num_heads = 0

  def _zero_heads(depth, num_heads):
    # Zeros of shape [batch_size, attention_init_length, depth], split into
    # `num_heads` attention heads.
    zeros = tf.zeros([batch_size, attention_init_length, depth])
    return common_attention.split_heads(zeros, num_heads)

  def _self_attention_entries():
    # Fresh (all-zero) keys/values for both self-attention sublayers.
    return {
        _SIXTEEN_HEAD_ATTENTION_NAME: {
            "k": _zero_heads(key_channels,
                             _capped_double_heads(hparams.num_heads)),
            "v": _zero_heads(value_channels,
                             _capped_double_heads(hparams.num_heads)),
        },
        _VANILLA_ATTENTION_NAME: {
            "k": _zero_heads(key_channels, hparams.num_heads),
            "v": _zero_heads(value_channels, hparams.num_heads),
        },
    }

  if cache is None:
    cache = {}
  layer_names = ["layer_%d" % index for index in range(num_layers)]

  # Self-attention caches (replacing any existing per-layer entries).
  cache.update({name: _self_attention_entries() for name in layer_names})

  # Convolution-branch buffers carry extra zeros for causal convolution.
  first_conv_shape = [batch_size,
                      attention_init_length + _DECODER_LEFT_CONV_PADDING,
                      hparams.hidden_size]
  second_conv_shape = [batch_size,
                       attention_init_length + _DECODER_FINAL_CONV_PADDING,
                       hparams.hidden_size * 2]
  for name in layer_names:
    layer_cache = cache[name]
    layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME] = tf.zeros(first_conv_shape)
    layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME] = tf.zeros(second_conv_shape)

  # Precomputed encoder keys/values for both attend-to-encoder sublayers.
  if encoder_output is not None:
    for attention_name in (_FIRST_ATTEND_TO_ENCODER_NAME,
                           _SECOND_ATTEND_TO_ENCODER_NAME):
      cache = _add_attend_to_encoder_cache(
          cache=cache,
          attention_name=attention_name,
          hparams=hparams,
          num_layers=num_layers,
          key_channels=key_channels,
          value_channels=value_channels,
          vars_3d_num_heads=vars_3d_num_heads,
          scope_prefix=scope_prefix,
          encoder_output=encoder_output)
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
  return cache
# TODO(davidso): Update optimizer, learning rate, and decay to match paper.
def add_evolved_transformer_hparams(hparams):
  """Adapt Transformer hparams to the Evolved Transformer.

  Note: these settings target the Adam optimizer, not the Adafactor optimizer
  used in the paper.

  Args:
    hparams: hparams object; modified in place.

  Returns:
    The same `hparams` object, updated with Evolved Transformer values.
  """
  # An Evolved Transformer "layer" is about twice as deep as a Transformer
  # layer, so roughly half as many are used. The counts follow
  # arxiv.org/abs/1901.11117.
  hparams.num_encoder_layers, hparams.num_decoder_layers = 3, 4

  # Mimic the Transformer Adam learning-rate configuration, but decay with a
  # single cosine cycle instead of rsqrt.
  warmup_scale = hparams.learning_rate_warmup_steps ** 0.5
  hparams.learning_rate_constant = (
      hparams.learning_rate_constant / warmup_scale)
  hparams.learning_rate_schedule = (
      "constant*linear_warmup*single_cycle_cos_decay*rsqrt_hidden_size")
  return hparams
@registry.register_hparams
def evolved_transformer_tiny():
  """Tiny parameters for Evolved Transformer model.

  Starts from `transformer_tiny`, applies `add_evolved_transformer_hparams`,
  and then replaces the learning-rate schedule with
  "constant*single_cycle_cos_decay". That drops the linear warmup and the
  rsqrt_hidden_size factor set by `add_evolved_transformer_hparams`.
  """
  hparams = add_evolved_transformer_hparams(transformer.transformer_tiny())
  hparams.learning_rate_schedule = "constant*single_cycle_cos_decay"
  return hparams
@registry.register_hparams
def evolved_transformer_base():
  """Base parameters for Evolved Transformer model (from transformer_base)."""
  base_hparams = transformer.transformer_base()
  return add_evolved_transformer_hparams(base_hparams)
@registry.register_hparams
def evolved_transformer_big():
  """Big parameters for Evolved Transformer model on WMT (transformer_big)."""
  big_hparams = transformer.transformer_big()
  return add_evolved_transformer_hparams(big_hparams)
@registry.register_hparams
def evolved_transformer_deep():
  """Deep parameters for Evolved Transformer model on WMT.

  Built on `transformer_big` with 9 encoder layers, 10 decoder layers and a
  hidden size of 640.
  """
  hparams = transformer.transformer_big()
  hparams = add_evolved_transformer_hparams(hparams)
  hparams.num_encoder_layers, hparams.num_decoder_layers = 9, 10
  hparams.hidden_size = 640
  return hparams
@registry.register_hparams
def evolved_transformer_base_tpu():
  """Base parameters for Evolved Transformer model on TPU.

  The learning-rate constant is reset to 1 / sqrt(warmup steps) and the
  schedule is plain single-cycle cosine decay.
  """
  hparams = transformer.transformer_tpu()
  hparams = add_evolved_transformer_hparams(hparams)
  warmup_steps = hparams.learning_rate_warmup_steps
  hparams.learning_rate_constant = 1 / warmup_steps ** 0.5
  hparams.learning_rate_schedule = "constant*single_cycle_cos_decay"
  return hparams
@registry.register_hparams
def evolved_transformer_big_tpu():
  """Big parameters for Evolved Transformer model on TPU.

  The learning-rate constant is reset to 1 / sqrt(warmup steps) and the
  schedule is plain single-cycle cosine decay.
  """
  hparams = transformer.transformer_big_tpu()
  hparams = add_evolved_transformer_hparams(hparams)
  warmup_steps = hparams.learning_rate_warmup_steps
  hparams.learning_rate_constant = 1 / warmup_steps ** 0.5
  hparams.learning_rate_schedule = "constant*single_cycle_cos_decay"
  return hparams
@registry.register_hparams
def evolved_transformer_tpu_basic():
  """Basic Seq2Seq TPU hyper-parameters.

  Based directly on `transformer_big_tpu`; unlike the other variants here,
  `add_evolved_transformer_hparams` is not applied.
  """
  hparams = transformer.transformer_big_tpu()
  hparams.add_hparam("print_vars", False)
  hparams.batch_size, hparams.max_length = 8192, 256
  # num_trainable_top_decoder_layers = N:
  #   N < 0  -> every weight in the model is trainable.
  #   N >= 0 -> all weights are frozen except the top N decoder layers and
  #             the (pre-)softmax matrix projecting hidden size to vocab size.
  hparams.add_hparam("num_trainable_top_decoder_layers", -1)
  return hparams
================================================
FILE: tensor2tensor/models/evolved_transformer_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Evolved Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import evolved_transformer
from tensor2tensor.models import transformer
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# Dimensions of the random synthetic problem built by `get_model`.
BATCH_SIZE = 3
INPUT_LENGTH = 5
TARGET_LENGTH = 7
VOCAB_SIZE = 10
# Value passed as `decode_length` to the inference comparisons.
DECODE_LENGTH = 3
def print_vars(all_vars=None):
  """Log name, shape and device placement of each variable.

  Args:
    all_vars: variables to log. When None or empty, all trainable variables
      (`tf.trainable_variables()`) are logged instead.
  """
  if not all_vars:
    all_vars = tf.trainable_variables()
  # The header documents the column order of the per-variable lines below;
  # the <name> and <shape> placeholders had gone missing from it.
  tf.logging.info("Format: <name>, <shape>, <(soft) device placement>")
  for var in all_vars:
    # Pass the values as logging args so formatting is done by the logger.
    tf.logging.info(" %s, %s, %s", var.name, str(var.get_shape()),
                    var.op.device)
def get_var(name):
  """Return the single trainable variable whose `.name` equals `name`.

  Raises:
    ValueError: if no trainable variable, or more than one, has that name.
  """
  matches = [v for v in tf.trainable_variables() if v.name == name]
  if len(matches) != 1:
    raise ValueError(
        "`name` must match exactly one variable. '%s' matched %d" %
        (name, len(matches)))
  return matches[0]
def get_vars(names):
  """Return the trainable variables named in `names`, in the same order.

  Each name is resolved with `get_var`, so a ValueError is raised for any
  name that does not match exactly one trainable variable.
  """
  return list(map(get_var, names))
def assert_with_message(assert_method, a, b, message):
  """Call `assert_method(a, b)` and log `message` if it fails.

  Args:
    assert_method: callable taking `(a, b)` that raises AssertionError on
      failure, e.g. a `tf.test.TestCase` assertion method.
    a: first argument forwarded to `assert_method`.
    b: second argument forwarded to `assert_method`.
    message: text logged at error level before the failure is re-raised.

  Raises:
    AssertionError: re-raised from `assert_method`.
  """
  try:
    assert_method(a, b)
  except AssertionError:
    tf.logging.error(message)
    # Bare `raise` re-raises the active exception with its original
    # traceback (`raise e` drops it under Python 2).
    raise
def get_model(hparams, has_input=True, num_decoder_layers=1):
  """Create a tiny EvolvedTransformer in TRAIN mode plus random features.

  Args:
    hparams: hparams shrunk in place (hidden size 4, one head, one encoder
      layer, no layer pre/post-process dropout).
    has_input: when False, the "inputs" modality is removed from the problem
      hparams and no "inputs" feature is produced.
    num_decoder_layers: number of decoder layers.

  Returns:
    A `(model, features)` tuple. `features` maps "targets", "target_space_id"
    and, if `has_input`, "inputs" to int32 constants.
  """
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.hidden_size = 4
  hparams.num_heads = 1
  hparams.num_encoder_layers = 1
  hparams.num_decoder_layers = num_decoder_layers

  p_hparams = problem_hparams.test_problem_hparams(
      VOCAB_SIZE, VOCAB_SIZE, hparams)
  if not has_input:
    del p_hparams.modality["inputs"]
  hparams.problem_hparams = p_hparams

  # `inputs` is always sampled (before `targets`), even when it is unused.
  input_shape = (BATCH_SIZE, INPUT_LENGTH, 1, 1)
  target_shape = (BATCH_SIZE, TARGET_LENGTH, 1, 1)
  inputs = np.random.randint(VOCAB_SIZE, size=input_shape)
  targets = np.random.randint(VOCAB_SIZE, size=target_shape)

  features = {
      "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
      "target_space_id": tf.constant(1, dtype=tf.int32),
  }
  if has_input:
    features["inputs"] = tf.constant(inputs, dtype=tf.int32, name="inputs")

  model = evolved_transformer.EvolvedTransformer(
      hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
  return model, features
class EvolvedTransformerTest(tf.test.TestCase):
def testEvolvedTransformer(self):
  """A TRAIN-mode forward pass yields logits of the expected shape."""
  model, features = get_model(hparams=transformer.transformer_tiny())
  logits, _ = model(features)
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    logits_value = sess.run(logits)
  expected_shape = (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)
  self.assertEqual(logits_value.shape, expected_shape)
def testSlowVsFast(self):
  """Slow greedy decoding and cached (fast) greedy decoding must agree."""
  tf.set_random_seed(1234)
  decode_length = DECODE_LENGTH
  # Build the tiny model, train it for 10 Adam steps and switch it to
  # PREDICT mode; this is exactly the setup previously duplicated inline.
  model, features = self._create_greedy_infer_model()
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    greedy_result = model._slow_greedy_infer(
        features, decode_length)["outputs"]
    greedy_result = tf.squeeze(greedy_result, axis=[2, 3])
    fast_result = model._greedy_infer(features, decode_length)["outputs"]
  with self.test_session():
    greedy_res = greedy_result.eval()
    fast_res = fast_result.eval()
  self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length))
  self.assertAllClose(greedy_res, fast_res)
def testSlowVsFastNoInput(self):
  """Without inputs, slow and fast greedy decoding must agree."""
  model, features = get_model(transformer.transformer_tiny(), has_input=False)
  decode_length = DECODE_LENGTH

  # Train briefly so the comparison is not made on freshly initialized
  # weights.
  logits, _ = model(features)
  flat_logits = tf.reshape(tf.squeeze(logits, axis=[2, 3]), [-1, VOCAB_SIZE])
  flat_labels = tf.reshape(features["targets"], [-1])
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=flat_logits, labels=flat_labels))
  train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
  with self.test_session():
    tf.global_variables_initializer().run()
    for _ in range(10):
      train_op.run()

  model.set_mode(tf_estimator.ModeKeys.PREDICT)
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    slow_outputs = model._slow_greedy_infer(
        features, decode_length)["outputs"]
    slow_outputs = tf.squeeze(slow_outputs, axis=[2, 3])
    fast_outputs = model._greedy_infer(features, decode_length)["outputs"]
  with self.test_session():
    slow_res = slow_outputs.eval()
    fast_res = fast_outputs.eval()
  self.assertEqual(slow_res.shape, (BATCH_SIZE, decode_length))
  self.assertAllClose(slow_res, fast_res)
def testBeamVsFast(self):
  """Slow beam search and cached (fast) beam search must agree."""
  decode_length = DECODE_LENGTH
  # Build the tiny model, train it for 10 Adam steps and switch it to
  # PREDICT mode; this is exactly the setup previously duplicated inline.
  model, features = self._create_greedy_infer_model()
  beam_kwargs = dict(beam_size=4, top_beams=1, alpha=1.0)
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    beam_result = model._beam_decode_slow(
        features, decode_length, **beam_kwargs)["outputs"]
    fast_result = model._beam_decode(
        features, decode_length, **beam_kwargs)["outputs"]
  with self.test_session():
    beam_res = beam_result.eval()
    fast_res = fast_result.eval()
  self.assertAllClose(beam_res, fast_res)
def _create_greedy_infer_model(self):
  """Creates model for greedy inference testing.

  The tiny model is trained for 10 Adam steps (learning rate 0.001) and then
  switched to PREDICT mode.

  Returns:
    model: A t2t model.
    features: A map of string to tensor.
  """
  model, features = get_model(transformer.transformer_tiny())
  logits, _ = model(features)
  flat_logits = tf.reshape(tf.squeeze(logits, axis=[2, 3]), [-1, VOCAB_SIZE])
  flat_labels = tf.reshape(features["targets"], [-1])
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=flat_logits, labels=flat_labels))
  train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
  with self.test_session():
    tf.global_variables_initializer().run()
    for _ in range(10):
      train_op.run()
  model.set_mode(tf_estimator.ModeKeys.PREDICT)
  return model, features
def testGreedySlowTPUVsNonTPU(self):
  """Slow greedy decoding must match between TPU and non-TPU code paths."""
  decode_length = DECODE_LENGTH
  model, features = self._create_greedy_infer_model()
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    non_tpu_outputs = tf.squeeze(
        model._slow_greedy_infer(features, decode_length)["outputs"],
        axis=[2, 3])
    tpu_outputs = tf.squeeze(
        model._slow_greedy_infer_tpu(features, decode_length)["outputs"],
        axis=[2, 3])
  with self.test_session():
    non_tpu_res = non_tpu_outputs.eval()
    tpu_res = tpu_outputs.eval()
  expected_shape = (BATCH_SIZE, INPUT_LENGTH + decode_length)
  self.assertEqual(tpu_res.shape, expected_shape)
  self.assertAllClose(tpu_res, non_tpu_res)
def testGreedyFastTPUVsNonTPU(self):
  """Fast greedy decoding must match with use_tpu=False and use_tpu=True."""
  tf.set_random_seed(1234)
  decode_length = DECODE_LENGTH
  model, features = self._create_greedy_infer_model()
  outputs_by_tpu_flag = {}
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    for use_tpu in (False, True):
      outputs_by_tpu_flag[use_tpu] = model._greedy_infer(
          features, decode_length, use_tpu=use_tpu)["outputs"]
  with self.test_session():
    fast_non_tpu_res = outputs_by_tpu_flag[False].eval()
    fast_tpu_res = outputs_by_tpu_flag[True].eval()
  expected_shape = (BATCH_SIZE, INPUT_LENGTH + decode_length)
  self.assertEqual(fast_tpu_res.shape, expected_shape)
  self.assertAllClose(fast_tpu_res, fast_non_tpu_res)
def testGreedyTPUSlowVsFast(self):
  """On the TPU code path, slow and fast greedy decoding must agree."""
  tf.set_random_seed(1234)
  decode_length = DECODE_LENGTH
  model, features = self._create_greedy_infer_model()
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    slow_outputs = model._slow_greedy_infer_tpu(
        features, decode_length)["outputs"]
    slow_outputs = tf.squeeze(slow_outputs, axis=[2, 3])
    fast_outputs = model._greedy_infer(
        features, decode_length, use_tpu=True)["outputs"]
  with self.test_session():
    slow_res, fast_res = slow_outputs.eval(), fast_outputs.eval()
  expected_shape = (BATCH_SIZE, INPUT_LENGTH + decode_length)
  self.assertEqual(fast_res.shape, expected_shape)
  self.assertAllClose(fast_res, slow_res)
def testFrozenWeightsUnchangedByTraining(self):
  """Only the top decoder layer and softmax train when freezing is enabled.

  With `num_trainable_top_decoder_layers=1` and a 3-layer decoder, ten
  training steps must leave the weights in `frozen_names` (shared embedding,
  target-space embedding, encoder, decoder layers 0-1) unchanged, and must
  change every weight in `train_names` (decoder layer 2, the final decoder
  layer norm and the listed softmax shards).
  """
  # Arrange.
  hparams = transformer.transformer_tiny()
  # Freeze everything except the top 1 decoder layer (and the softmax).
  hparams.add_hparam("num_trainable_top_decoder_layers", 1)
  model, features = get_model(hparams, num_decoder_layers=3)
  out_logits, _ = model(features)
  out_logits = tf.squeeze(out_logits, axis=[2, 3])
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
      labels=tf.reshape(features["targets"], [-1]))
  loss = tf.reduce_mean(loss)
  apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss)
  # Full variable names that must NOT change during training.
  frozen_names = [
      "evolved_transformer/symbol_modality_10_4/shared/weights_0:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_1:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_2:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_3:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_4:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_5:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_6:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_7:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_8:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_9:0",
      "evolved_transformer/body/target_space_embedding/kernel:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/kernel:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/bias:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/kernel:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/dense/kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/dense/bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/depthwise_kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/pointwise_kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/bias:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense/kernel:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense/bias:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/bias:0",
      "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/bias:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense/kernel:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense/bias:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/bias:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/bias:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense/kernel:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense/bias:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/bias:0",
  ]
  # Full variable names that MUST change during training.
  train_names = [
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/bias:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense/kernel:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense/bias:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/bias:0",
      "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      # NOTE(review): softmax shards 1-9 are listed but weights_0 is not,
      # although shared/weights_0 appears in frozen_names. Confirm whether
      # omitting softmax/weights_0 is intentional.
      "evolved_transformer/symbol_modality_10_4/softmax/weights_1:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_2:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_3:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_4:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_5:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_6:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_7:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_8:0",
      "evolved_transformer/symbol_modality_10_4/softmax/weights_9:0",
  ]
  # get_vars raises ValueError if any name is missing or ambiguous, so the
  # lists above also pin down the exact variable naming.
  frozen_vars = get_vars(frozen_names)
  train_vars = get_vars(train_names)
  print_vars()
  # Act.
  with self.test_session() as session:
    tf.global_variables_initializer().run()
    frozen_values_before = session.run(frozen_vars)
    train_values_before = session.run(train_vars)
    for _ in range(10):  # Arbitrary number of training steps.
      apply_grad.run()
    frozen_values_after = session.run(frozen_vars)
    train_values_after = session.run(train_vars)
  # Assert.
  # The original hparams tie embedding and softmax weights; the model's
  # effective hparams are expected to untie them while keeping
  # `shared_embedding` enabled.
  self.assertTrue(
      model._original_hparams.shared_embedding_and_softmax_weights)
  self.assertFalse(model.hparams.shared_embedding_and_softmax_weights)
  self.assertTrue(model.hparams.shared_embedding)
  # Each list is zipped with its before/after values in the same order.
  for name, before, after in zip(frozen_names, frozen_values_before,
                                 frozen_values_after):
    assert_with_message(
        self.assertAllClose, before, after,
        "%s should be frozen, but changed after training." % name)
  for name, before, after in zip(train_names, train_values_before,
                                 train_values_after):
    assert_with_message(
        self.assertNotAllClose, before, after,
        "%s should be trainable, but did not change after training." % name)
def testAllWeightsTrainableByDefault(self):
  """Checks that every model variable is trainable when nothing is frozen.

  Builds a model from `transformer.transformer_tiny()` with three decoder
  layers, runs ten Adam steps on a mean softmax cross-entropy loss and
  asserts that:
    * embedding/softmax weight sharing stays enabled in both the original and
      the effective hparams, and `shared_embedding` is off;
    * the trainable variables are exactly `var_names`;
    * every variable in `var_names` changes during training, except the
      shards listed in `empty_vars`, which must have size 0.
  """
  # Arrange.
  model, features = get_model(
      transformer.transformer_tiny(), num_decoder_layers=3)
  out_logits, _ = model(features)
  # tf.squeeze with explicit axes requires axes 2 and 3 to have size 1; the
  # squeezed logits are then flattened to [-1, VOCAB_SIZE] for the loss.
  out_logits = tf.squeeze(out_logits, axis=[2, 3])
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
      labels=tf.reshape(features["targets"], [-1]))
  loss = tf.reduce_mean(loss)
  apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss)
  # Fully qualified names of every variable the model is expected to create;
  # all of them should be trainable in the default configuration.
  var_names = [
      "evolved_transformer/symbol_modality_10_4/shared/weights_0:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_1:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_2:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_3:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_4:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_5:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_6:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_7:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_8:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_9:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_10:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_11:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_12:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_13:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_14:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_15:0",
      "evolved_transformer/body/target_space_embedding/kernel:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/kernel:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense/bias:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/kernel:0",
      "evolved_transformer/body/encoder/layer_0/gated_linear_unit/dense_1/bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/dense/kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/dense/bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/standard_conv_3x1/bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/depthwise_kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/pointwise_kernel:0",
      "evolved_transformer/body/encoder/layer_0/conv_branches/separable_conv_9x1/bias:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/encoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense/kernel:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense/bias:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/encoder/layer_0/dense_layers/dense_1/bias:0",
      "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/16_head_self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv11x1/bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_1/bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_0/conv_branches/separable_conv_7x1_2/bias:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_0/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense/kernel:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense/bias:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/decoder/layer_0/dense_layers/dense_1/bias:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/16_head_self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv11x1/bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_1/bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_1/conv_branches/separable_conv_7x1_2/bias:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_1/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense/kernel:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense/bias:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/decoder/layer_1/dense_layers/dense_1/bias:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/16_head_self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/first_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv11x1/bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_1/bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/depthwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/pointwise_kernel:0",
      "evolved_transformer/body/decoder/layer_2/conv_branches/separable_conv_7x1_2/bias:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/self_attention/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/q/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/k/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/v/kernel:0",
      "evolved_transformer/body/decoder/layer_2/second_attend_to_encoder/multihead_attention/output_transform/kernel:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense/kernel:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense/bias:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/layer_prepostprocess_1/layer_norm/layer_norm_bias:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/kernel:0",
      "evolved_transformer/body/decoder/layer_2/dense_layers/dense_1/bias:0",
      "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_scale:0",
      "evolved_transformer/body/decoder/layer_prepostprocess/layer_norm/layer_norm_bias:0",
  ]
  variables = get_vars(var_names)
  # Helper defined elsewhere in this test file; presumably dumps the graph's
  # variables for debugging.
  print_vars()
  # Act.
  with self.test_session() as session:
    tf.global_variables_initializer().run()
    values_before = session.run(variables)
    for _ in range(10):  # Arbitrary number of training steps.
      apply_grad.run()
    values_after = session.run(variables)
  # Assert.
  # By default the embedding and softmax weights stay shared (in both the
  # original and the effective hparams) and shared_embedding is off.
  self.assertTrue(
      model._original_hparams.shared_embedding_and_softmax_weights)
  self.assertTrue(model.hparams.shared_embedding_and_softmax_weights)
  self.assertFalse(model.hparams.shared_embedding)
  # The set of trainable variables must match var_names exactly: nothing
  # missing and nothing unexpected.
  self.assertSameElements(var_names,
                          [var.name for var in tf.trainable_variables()])
  # These shards of the shared weights are asserted below to hold zero
  # elements (presumably the vocabulary has fewer rows than there are shards
  # -- TODO confirm). An empty tensor cannot change, so only sizes are checked.
  empty_vars = {
      "evolved_transformer/symbol_modality_10_4/shared/weights_10:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_11:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_12:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_13:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_14:0",
      "evolved_transformer/symbol_modality_10_4/shared/weights_15:0"
  }
  for name, before, after in zip(var_names, values_before, values_after):
    if name in empty_vars:
      self.assertEqual(before.size, after.size)
      self.assertEqual(before.size, 0)
    else:
      # Every non-empty variable must have been updated by the optimizer.
      assert_with_message(
          self.assertNotAllClose, before, after,
          "%s should be trainable, but did not change after training." % name)
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test runner when the file is
  # executed directly.
  tf.test.main()
================================================
FILE: tensor2tensor/models/image_transformer.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""image generation with transformer (attention).
encoder: [Self-Attention, Feed-forward] x n
decoder: [Self-Attention, Source-Target-Attention, Feed-forward] x n
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
@registry.register_model
class Imagetransformer(t2t_model.T2TModel):
  """Conditional image generation with attention. See file docstring.

  The model admits either a Categorical or discretized mixture of logistic
  distributions (DMOL) as the likelihood. When using DMOL for training, double
  check that the evaluation metrics also use it.
  """

  def body(self, features):
    """Runs the decoder stack over the targets, optionally class-conditioned.

    Args:
      features: a map of string to `Tensor`. Always reads "targets". Unless
        `hparams.unconditional` is set, it also reads "inputs", which is added
        to the decoder input.

    Returns:
      The result of `cia.create_output`. When the decoder layers appended
      auxiliary losses, returns a tuple `(output, {"extra_loss": total})`
      instead.

    Raises:
      ValueError: if the likelihood is DMOL and `hparams.num_channels` != 1.
    """
    hparams = copy.copy(self._hparams)
    targets = features["targets"]
    if (hparams.likelihood == cia.DistributionType.DMOL and
        hparams.num_channels != 1):
      # The two literals previously concatenated into "function  must" (a
      # doubled space); the message now reads as one sentence.
      raise ValueError("When using DMOL for the likelihood, bottom function "
                       "must be identity and num_channels must be 1.")
    if (not tf.get_variable_scope().reuse and
        hparams.mode != tf_estimator.ModeKeys.PREDICT):
      # Summarize the target images, except when variables are being reused
      # or when predicting.
      tf.summary.image("targets", tf.to_float(targets), max_outputs=1)

    # Decoder layers append auxiliary losses here (e.g. when using MoE).
    losses = []

    # Prepare decoder inputs and bias.
    decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
    # Add class label to decoder input.
    if not hparams.unconditional:
      inputs = features["inputs"]
      # Reshape to [batch, 1, 1, hidden_size] so the per-example vector
      # broadcasts over decoder_input's middle axes. This assumes
      # decoder_input is 4-D and `inputs` holds batch * hidden_size values
      # (an already-embedded label) -- TODO confirm.
      decoder_input += tf.reshape(
          inputs,
          [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        None,
        # num_decoder_layers takes precedence; num_hidden_layers is only the
        # fallback when it is falsy.
        hparams.num_decoder_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        losses=losses,
        name="decoder")
    output = cia.create_output(decoder_output, rows, cols, targets, hparams)

    if losses:
      return output, {"extra_loss": tf.add_n(losses)}
    return output

  def loss(self, logits, features):
    """Returns the DMOL loss when configured, otherwise the default T2T loss."""
    if self._hparams.likelihood == cia.DistributionType.DMOL:
      return common_layers.dml_loss(logits, features["targets"])
    return super(Imagetransformer, self).loss(logits, features)

  def sample(self, features):
    """Run the model and extract samples.

    With a DMOL likelihood, samples are drawn from the discretized mixture of
    logistics produced by the model. Otherwise the default T2T sampling is
    used.

    Args:
      features: a map of string to `Tensor`.

    Returns:
      samples: an integer `Tensor`.
      logits: a list of `Tensor`s, one per datashard.
      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
    """
    if self._hparams.likelihood == cia.DistributionType.DMOL:
      logits, losses = self(features)  # pylint: disable=not-callable
      samples = common_layers.sample_from_discretized_mix_logistic(
          logits, seed=None)
      return samples, logits, losses
    return super(Imagetransformer, self).sample(features)

  def _slow_greedy_infer(self, features, decode_length):
    """A slow greedy inference method.

    Quadratic time in decode_length.

    Args:
      features: a map of string to `Tensor`.
      decode_length: an integer. How many additional timesteps to decode.

    Returns:
      samples: an integer `Tensor`.
      logits: `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
      losses: a dictionary: {loss-name (string): floating point `Scalar`}

    Raises:
      NotImplementedError: if the likelihood is DMOL.
    """
    if self._hparams.likelihood == cia.DistributionType.DMOL:
      raise NotImplementedError("Decoding is not currently available for DMOL.")
    return super(Imagetransformer, self)._slow_greedy_infer(features,
                                                            decode_length)
@registry.register_model
class ImagetransformerMoe(t2t_model.T2TModel):
  """Conditional image generation with attention and MoE.

  The decoder is built in `body_sharded` with
  `cia.transformer_layers_sharded`; `use_body_sharded` returns True.
  """

  @staticmethod
  def use_body_sharded():
    """Always True: this model provides `body_sharded`."""
    return True

  def body_sharded(self, sharded_features):
    """Builds the sharded decoder.

    Args:
      sharded_features: map of string to sharded features. Reads "inputs" and
        "targets" (presumably one entry per data-parallel shard -- TODO
        confirm).

    Returns:
      A pair: the output of `cia.create_output` mapped over the shards, and
      the extra loss returned by `cia.transformer_layers_sharded`.
    """
    parallel = self._data_parallelism
    hp = copy.copy(self._hparams)
    features_in = sharded_features["inputs"]
    features_out = sharded_features["targets"]

    # Multipos attention padding: "LEFT" for filters wider than one.
    q_padding = "LEFT" if hp.q_filter_width > 1 else "VALID"
    kv_padding = "LEFT" if hp.kv_filter_width > 1 else "VALID"

    # Prepare decoder inputs and bias on every shard.
    decoder_input, rows, cols = parallel(
        cia.prepare_decoder_inputs, features_in, features_out, hp)

    # TODO(nikip): Use q_padding and kv_padding
    del q_padding, kv_padding
    decoder_output, extra_loss = cia.transformer_layers_sharded(
        parallel,
        self._ps_devices,
        decoder_input,
        hp.num_hidden_layers,
        hp,
        self_attention_bias=None,
        enc_output=None,
        attention_type=hp.dec_attention_type,
        name="decoder")
    sharded_output = parallel(
        cia.create_output, decoder_output, rows, cols, features_out, hp)
    return sharded_output, extra_loss
@registry.register_hparams
def image_transformer_base():
  """Base hyperparameters for the image transformer models in this file.

  Starts from `common_hparams.basic_params1()`, overrides generic model and
  optimizer settings, then registers the image-transformer-specific
  hyperparameters with `add_hparam`.

  Returns:
    The populated hparams object.
  """
  hparams = common_hparams.basic_params1()
  hparams.hidden_size = 512
  hparams.batch_size = 4
  hparams.max_length = 3075
  hparams.dropout = 0.0
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.optimizer_adam_epsilon = 1e-9
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.learning_rate_warmup_steps = 4000
  hparams.initializer_gain = 0.2
  # Used as the layer count by ImagetransformerMoe. Imagetransformer only falls
  # back to it when num_decoder_layers (set to 12 below) is falsy.
  hparams.num_hidden_layers = 6
  hparams.initializer = "uniform_unit_scaling"
  hparams.weight_decay = 0.0
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.98
  hparams.label_smoothing = 0.0
  hparams.bottom["targets"] = modalities.image_channel_embeddings_bottom
  hparams.top["targets"] = modalities.identity_top
  hparams.norm_type = "layer"
  hparams.layer_prepostprocess_dropout = 0.0
  # Image-transformer-specific hyperparameters are registered with add_hparam.
  hparams.add_hparam("filter_size", 512)  # Add new ones like this.
  # attention-related flags
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("attention_key_channels", 0)
  hparams.add_hparam("attention_value_channels", 0)
  hparams.add_hparam("ffn_layer", "conv_hidden_relu")
  # All hyperparameters ending in "dropout" are automatically set to 0.0
  # when not in training mode.
  hparams.add_hparam("attention_dropout", 0.0)
  hparams.add_hparam("relu_dropout", 0.0)
  # Positional signal: "timing" (default), "emb" (used by
  # imagetransformer_cifar10_base_dmol) or "none".
  hparams.add_hparam("pos", "timing")
  hparams.add_hparam("nbr_decoder_problems", 1)
  hparams.add_hparam("num_output_layers", 3)
  hparams.add_hparam("block_size", 1)
  # dilated attention based flags
  hparams.add_hparam("gap_sizes", [2, 4, 8, 16, 32, 64, 2, 4, 8, 16, 32, 64])
  # image size related flags
  # assuming that the image has same height and width
  hparams.add_hparam("img_len", 32)
  # DMOL likelihood requires num_channels == 1 (checked in
  # Imagetransformer.body).
  hparams.add_hparam("num_channels", 3)
  # Local attention params
  hparams.add_hparam("local_and_global_att", False)
  hparams.add_hparam("block_length", 256)
  hparams.add_hparam("block_width", 128)
  hparams.add_hparam("num_encoder_layers", 4)
  # Decoder depth used by Imagetransformer.body (preferred over
  # num_hidden_layers).
  hparams.add_hparam("num_decoder_layers", 12)
  hparams.add_hparam("dec_attention_type", cia.AttentionType.LOCAL_1D)
  hparams.add_hparam("block_raster_scan", False)
  # multipos attention params
  hparams.add_hparam("q_filter_width", 1)
  hparams.add_hparam("kv_filter_width", 1)
  # Output likelihood: categorical (CAT) or cia.DistributionType.DMOL.
  hparams.add_hparam("likelihood", cia.DistributionType.CAT)
  hparams.add_hparam("unconditional", False)  # unconditional generation
  # parameters of discretized mixture of logistics loss from pixel cnn++
  hparams.add_hparam("num_mixtures", 10)
  # These parameters are only used when ffn_layer=="local_moe_tpu"
  hparams.add_hparam("moe_overhead_train", 1.0)
  hparams.add_hparam("moe_overhead_eval", 2.0)
  hparams.moe_num_experts = 8
  hparams.moe_loss_coef = 1e-3
  # These parameters are for relative attention
  hparams.add_hparam("shared_rel", False)  # share relative embeddings
  return hparams
@registry.register_hparams
def imagetransformer_base():
  """Registers `image_transformer_base` under the name imagetransformer_base."""
  return image_transformer_base()
@registry.register_hparams
def imagetransformer_cifar10_base():
  """Cross-entropy config reported at 2.90 bits/dim on CIFAR-10.

  Unconditional generation: 12 decoder layers, 4 heads, hidden size 512,
  filter size 2048, "dan" post-processing with 0.3 pre/post dropout.
  """
  config = image_transformer_base()
  overrides = (
      ("batch_size", 4),
      ("num_heads", 4),
      ("num_decoder_layers", 12),
      ("block_length", 256),
      ("hidden_size", 512),
      ("filter_size", 2048),
      ("learning_rate", 0.5),
      ("learning_rate_warmup_steps", 4000),
      ("layer_preprocess_sequence", "none"),
      ("layer_postprocess_sequence", "dan"),
      ("layer_prepostprocess_dropout", 0.3),
      ("unconditional", True),
  )
  for key, value in overrides:
    setattr(config, key, value)
  return config
@registry.register_hparams
def imagetransformer_cifar10_base_dmol():
  """Best config for 2.90 bits/dim on CIFAR10 using DMOL.

  Uses the discretized mixture of logistics likelihood with a single channel,
  the channel-compressing targets bottom, `pos = "emb"`, and unconditional
  generation.
  """
  hparams = image_transformer_base()
  hparams.likelihood = cia.DistributionType.DMOL
  # DMOL requires num_channels == 1 (enforced in Imagetransformer.body).
  hparams.num_channels = 1
  hparams.bottom["targets"] = modalities.image_channel_compress_targets_bottom
  hparams.top["targets"] = modalities.identity_top
  hparams.num_heads = 8
  hparams.batch_size = 8
  hparams.sampling_method = "random"
  hparams.summarize_grads = True
  hparams.hidden_size = 256
  hparams.filter_size = 512
  hparams.attention_key_channels = 512
  hparams.attention_value_channels = 512
  hparams.num_decoder_layers = 12
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.learning_rate = 0.1
  # Each sequence setting is assigned exactly once. The earlier "n"/"da"
  # assignments were always overwritten by these values.
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.pos = "emb"
  hparams.unconditional = True
  return hparams
@registry.register_hparams
def imagetransformer_base_tpu():
  """Base image transformer hyperparameters for TPU training.

  Starts from `imagetransformer_bas8l_8h_big_uncond_dr03_imgnet`, applies
  `update_hparams_for_tpu`, then overrides the values listed below.
  NOTE(review): the original docstring said "for cifar-10", while the parent
  config's name suggests ImageNet. Confirm the intended dataset.
  """
  tpu_hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(tpu_hparams)
  overrides = (
      ("batch_size", 4),
      ("num_heads", 4),  # heads are expensive on tpu
      ("num_decoder_layers", 12),
      ("block_length", 128),
      ("hidden_size", 512),
      ("filter_size", 2048),
      ("learning_rate", 0.2),
      ("learning_rate_warmup_steps", 6000),
      ("layer_preprocess_sequence", "none"),
      ("layer_postprocess_sequence", "dan"),
      ("layer_prepostprocess_dropout", 0.3),
  )
  for name, value in overrides:
    setattr(tpu_hparams, name, value)
  return tpu_hparams
@registry.register_hparams
def imagetransformer_base_imagenet_tpu():
  """Transformer base params for ImageNet on TPU.

  Derived from `imagetransformer_base_tpu`. The effective change is a lower
  layer_prepostprocess_dropout (0.1 instead of 0.3). The other assignments
  restate values the parent already sets; they are kept so the config stays
  explicit.
  """
  hparams = imagetransformer_base_tpu()
  hparams.batch_size = 4
  hparams.num_heads = 4  # heads are expensive on tpu
  hparams.num_decoder_layers = 12
  hparams.block_length = 128
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.layer_prepostprocess_dropout = 0.1
  return hparams
@registry.register_hparams
def imagetransformer_imagenet32_base():
  """Cross-entropy config for ImageNet-32, reported at 3.77 bits/dim.

  Identical to `imagetransformer_cifar10_base` apart from a
  layer_prepostprocess_dropout of 0.1.
  """
  hp = imagetransformer_cifar10_base()
  hp.batch_size = 4
  hp.layer_prepostprocess_dropout = 0.1
  return hp
@registry.register_hparams
def imagetransformer_base_rel():
  """`imagetransformer_base` with relative local 1D decoder attention."""
  hp = imagetransformer_base()
  hp.dec_attention_type = cia.AttentionType.RELATIVE_LOCAL_1D
  return hp
@registry.register_hparams
def imagetransformer_sep_channels():
  """Separate RGB embeddings: hidden size 256, filter size 512, 4 heads.

  NOTE(review): Imagetransformer.body uses num_decoder_layers (12 from the
  base config) in preference to num_hidden_layers. The layer count set here
  therefore only affects code that reads num_hidden_layers; confirm intent.
  """
  hp = imagetransformer_base()
  overrides = (
      ("num_heads", 4),
      ("attention_key_channels", 0),
      ("attention_value_channels", 0),
      ("hidden_size", 256),
      ("filter_size", 512),
      ("num_hidden_layers", 6),
  )
  for name, value in overrides:
    setattr(hp, name, value)
  return hp
@registry.register_hparams
def imagetransformer_sep_channels_8l():
  """Separate RGB embeddings: hidden/filter size 256, 4 heads, random sampling.

  NOTE(review): Imagetransformer.body uses num_decoder_layers (12 from the
  base config) in preference to num_hidden_layers. The 8 layers set here
  therefore only affect code that reads num_hidden_layers; confirm intent.
  """
  hp = imagetransformer_base()
  overrides = (
      ("num_heads", 4),
      ("attention_key_channels", 0),
      ("attention_value_channels", 0),
      ("hidden_size", 256),
      ("filter_size", 256),
      ("num_hidden_layers", 8),
      ("sampling_method", "random"),
  )
  for name, value in overrides:
    setattr(hp, name, value)
  return hp
@registry.register_hparams
def imagetransformer_sep_channels_8l_multipos3():
  """`imagetransformer_sep_channels_8l` with width-3 query and key/value filters."""
  hp = imagetransformer_sep_channels_8l()
  hp.q_filter_width = hp.kv_filter_width = 3
  return hp
@registry.register_hparams
def imagetransformer_base_8l_8h_big_cond_dr03_dan():
  """Big 1D model for conditional image generation; 2.99 on CIFAR-10.

  8 decoder layers, 8 heads, hidden size 512, filter size 2048, "dan"
  post-processing with 0.3 pre/post dropout.
  """
  big = imagetransformer_sep_channels_8l()
  overrides = (
      ("block_width", 256),
      ("block_length", 256),
      ("hidden_size", 512),
      ("num_heads", 8),
      ("filter_size", 2048),
      ("batch_size", 4),
      ("max_length", 3075),
      ("layer_preprocess_sequence", "none"),
      ("layer_postprocess_sequence", "dan"),
      ("num_decoder_layers", 8),
      ("layer_prepostprocess_dropout", 0.3),
  )
  for name, value in overrides:
    setattr(big, name, value)
  return big
@registry.register_hparams
def imagetransformer_base_10l_8h_big_uncond_dr03_dan_64():
  """Big 1D model for unconditional generation of 64x64 ImageNet images."""
  hp = imagetransformer_base_10l_8h_big_cond_dr03_dan()
  overrides = (
      ("unconditional", True),
      ("max_length", 14000),
      ("batch_size", 1),
      ("img_len", 64),
      ("layer_prepostprocess_dropout", 0.1),
  )
  for name, value in overrides:
    setattr(hp, name, value)
  return hp
@registry.register_hparams
def imagetransformerpp_sep_channels_8l_8h():
  """Separate RGB embeddings with the DMOL likelihood and 8 heads.

  NOTE(review): Imagetransformer.body uses num_decoder_layers (12 from the
  base config) in preference to num_hidden_layers. The 8 layers set here
  therefore only affect code that reads num_hidden_layers; confirm intent.
  """
  hp = imagetransformer_base()
  overrides = (
      ("likelihood", cia.DistributionType.DMOL),
      # DMOL requires num_channels == 1 (enforced in Imagetransformer.body).
      ("num_channels", 1),
      ("num_heads", 8),
      ("batch_size", 4),
      ("attention_key_channels", 0),
      ("attention_value_channels", 0),
      ("hidden_size", 512),
      ("filter_size", 512),
      ("num_hidden_layers", 8),
      ("sampling_method", "random"),
      ("layer_preprocess_sequence", "n"),
      ("layer_postprocess_sequence", "da"),
      ("summarize_grads", True),
      ("learning_rate", 0.1),
  )
  for name, value in overrides:
    setattr(hp, name, value)
  hp.bottom["targets"] = modalities.image_channel_compress_targets_bottom
  hp.top["targets"] = modalities.identity_top
  return hp
@registry.register_hparams
def imagetransformerpp_base_8l_8h_big_cond_dr03_dan():
  """Big DMOL model for conditional image generation.

  Filter 2048, "dan" post-processing with 0.3 dropout, learning rate 0.01.
  """
  hp = imagetransformerpp_sep_channels_8l_8h()
  hp.hidden_size = 512
  hp.num_heads = 8
  hp.filter_size = 2048
  hp.batch_size = 4
  hp.max_length = 3075
  hp.layer_prepostprocess_dropout = 0.3
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.summarize_grads = True
  hp.learning_rate = 0.01
  return hp
@registry.register_hparams
def imagetransformerpp_base_8l_8h_big_cond_dr03_dan_a():
  """imagetransformerpp_base_8l_8h_big_cond_dr03_dan with learning rate 0.1."""
  hp = imagetransformerpp_base_8l_8h_big_cond_dr03_dan()
  hp.learning_rate = 0.1
  return hp
@registry.register_hparams
def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan():
  """Unconditional, 10-decoder-layer variant of ..._cond_dr03_dan_a."""
  hp = imagetransformerpp_base_8l_8h_big_cond_dr03_dan_a()
  hp.unconditional = True
  hp.num_decoder_layers = 10
  return hp
@registry.register_hparams
def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_a():
  """imagetransformerpp_base_10l_8h_big_uncond_dr03_dan with lr 0.01."""
  hp = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan()
  hp.learning_rate = 0.01
  return hp
@registry.register_hparams
def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_b():
  """Narrower model (hidden 256) with 512-wide attention keys/values."""
  hp = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan()
  hp.learning_rate = 0.1
  hp.hidden_size = 256
  hp.attention_key_channels = 512
  hp.attention_value_channels = 512
  hp.filter_size = 1024
  return hp
@registry.register_hparams
def imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g():
  """_dan_b with filter 512, dropout 0.1 and "emb" position encoding."""
  hp = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_b()
  hp.filter_size = 512
  hp.layer_prepostprocess_dropout = 0.1
  hp.learning_rate = 0.1
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.pos = "emb"
  return hp
@registry.register_hparams
def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_k():
  """imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g with 12 layers."""
  hp = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g()
  hp.num_decoder_layers = 12
  return hp
@registry.register_hparams
def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_l():
  """12-layer _dan_g with gradient-norm clipping at 40."""
  hp = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g()
  hp.num_decoder_layers = 12
  hp.clip_grad_norm = 40.
  return hp
@registry.register_hparams
def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m():
  """imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_k with batch size 8."""
  hp = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_k()
  hp.batch_size = 8
  return hp
@registry.register_hparams
def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_rel():
  """Batch-size-8 _dan_k with RELATIVE_LOCAL_1D decoder attention."""
  hp = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_k()
  hp.batch_size = 8
  hp.dec_attention_type = cia.AttentionType.RELATIVE_LOCAL_1D
  return hp
@registry.register_hparams
def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_relsh():
  """_dan_m_rel with `shared_rel` enabled."""
  hp = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_rel()
  hp.shared_rel = True
  return hp
@registry.register_hparams
def imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_p():
  """Gets to 2.92 in just under 4 days on 8 P100s.

  14 decoder layers, batch size 8, pre/post-process dropout 0.2.
  """
  hp = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_l()
  hp.num_decoder_layers = 14
  hp.batch_size = 8
  hp.layer_prepostprocess_dropout = 0.2
  return hp
@registry.register_hparams
def imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m_bs1():
  """For 128x128 images: _dan_m with batch size 1."""
  # TODO(trandustin): why are these running? max_length and img_len are not
  # set here; 256x256 was also training without setting max_length.
  hp = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_m()
  hp.batch_size = 1
  return hp
@registry.register_hparams
def imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_p_bs1():
  """For 128x128 images: _dan_p with batch size 1."""
  hp = imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_p()
  hp.batch_size = 1
  return hp
@registry.register_hparams
def imagetransformerpp_base_5l_8h_big_uncond_dr00_dan_g_bs1():
  """For 256x256 images: small 5-layer model, batch size 1, no dropout."""
  hp = imagetransformerpp_base_10l_8h_big_uncond_dr03_dan_g()
  # TODO(trandustin): img_len (256) was never set in the runs. Maybe it is not
  # used by the image transformer training implementation?
  hp.max_length = 66000  # Room for a flattened 256x256 image.
  hp.batch_size = 1
  hp.num_decoder_layers = 5
  hp.hidden_size = 128
  hp.filter_size = 128
  hp.attention_key_channels = 64
  hp.attention_value_channels = 64
  hp.layer_prepostprocess_dropout = 0.0
  return hp
@registry.register_hparams
def imagetransformerpp_base_5l_8h_dr00_dan_g_bs1_adafactor():
  """For 256x256 images: the 5-layer bs1 model trained with Adafactor."""
  hp = imagetransformerpp_base_5l_8h_big_uncond_dr00_dan_g_bs1()
  # Adafactor needs less memory than Adam; use its recommended schedule.
  hp.optimizer = "Adafactor"
  hp.learning_rate_schedule = "rsqrt_decay"
  return hp
@registry.register_hparams
def imagetransformerpp_base_6l_8h_dr00_dan_g_bs1_adafactor():
  """For 256x256 images: 6-layer version of the Adafactor bs1 model."""
  hp = imagetransformerpp_base_5l_8h_dr00_dan_g_bs1_adafactor()
  hp.num_decoder_layers = 6
  return hp
@registry.register_hparams
def imagetransformerpp_base_14l_8h_big_uncond_dr03_dan_eval():
  """Like ..._dan_p (2.92 in under 4 days on 8 P100s) minus its dropout change.

  14 decoder layers and batch size 8, built on ..._dan_l.
  """
  hp = imagetransformerpp_base_12l_8h_big_uncond_dr03_dan_l()
  hp.num_decoder_layers = 14
  hp.batch_size = 8
  # Unlike _dan_p, layer_prepostprocess_dropout is intentionally not set to
  # 0.2 here; the value from _dan_l is kept.
  return hp
@registry.register_hparams
def imagetransformer_base_8l_8h_big_cond_dr03_dan_128():
  """imagetransformer_base_8l_8h_big_cond_dr03_dan with 128-wide blocks."""
  hp = imagetransformer_base_8l_8h_big_cond_dr03_dan()
  hp.block_width = 128
  hp.block_length = 128
  return hp
@registry.register_hparams
def imagetransformer_base_10l_8h_big_cond_dr03_dan():
  """Best conditional CIFAR-10 generation hparams (10 decoder layers)."""
  hp = imagetransformer_base_8l_8h_big_cond_dr03_dan()
  hp.num_decoder_layers = 10
  return hp
@registry.register_hparams
def imagetransformer_base_10l_8h_big_uncond_dr03_dan():
  """10-layer big 1D model registered under an "uncond" name.

  Originally documented as the best unconditional CIFAR-10 generation params.
  NOTE(review): this function never assigns `hparams.unconditional`, so the
  value comes from imagetransformer_base_10l_8h_big_cond_dr03_dan (none of the
  hparams sets it builds on in this file assign it either). In that case the
  config behaves the same as the "cond" one. Confirm whether
  `hparams.unconditional = True` was intended before relying on the name.
  """
  # The parent already sets num_decoder_layers = 10; no override is needed.
  hparams = imagetransformer_base_10l_8h_big_cond_dr03_dan()
  return hparams
@registry.register_hparams
def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated():
  """Dilated decoder attention with per-layer gap sizes and 128 blocks.

  Adds a `num_memory_blocks` hparam (initially 1).
  """
  hp = imagetransformer_base_8l_8h_big_cond_dr03_dan()
  hp.gap_sizes = [0, 16, 64, 0, 16, 64, 128, 0]
  hp.dec_attention_type = cia.AttentionType.DILATED
  hp.block_length = 128
  hp.block_width = 128
  hp.add_hparam("num_memory_blocks", 1)
  return hp
@registry.register_hparams
def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated_b():
  """Dilated variant: block width 64, 2 memory blocks."""
  hp = imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated()
  hp.block_width = 64
  hp.num_memory_blocks = 2
  return hp
@registry.register_hparams
def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated_c():
  """Dilated variant: block width 32, 4 memory blocks."""
  hp = imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated()
  hp.block_width = 32
  hp.num_memory_blocks = 4
  return hp
@registry.register_hparams
def imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated_d():
  """Dilated variant with larger gap sizes in the later layers."""
  hp = imagetransformer_base_8l_8h_big_cond_dr03_dan_dilated()
  hp.gap_sizes = [0, 16, 64, 16, 64, 128, 256, 0]
  return hp
@registry.register_hparams
def imagetransformer_base_12l_8h_big():
  """Big 12-layer 1D model for conditional image generation.

  Filter 1024, hidden 512, batch size 1, 4000 warmup steps, random sampling
  with beam size 1, block width 256.
  """
  hp = imagetransformer_sep_channels_8l_8h()
  hp.filter_size = 1024
  hp.num_decoder_layers = 12
  hp.batch_size = 1
  hp.hidden_size = 512
  hp.learning_rate_warmup_steps = 4000
  hp.sampling_method = "random"
  hp.beam_size = 1
  hp.block_width = 256
  return hp
@registry.register_hparams
def imagetransformer1d_base_8l_64by64():
  """Hparams for an 8-layer big 1D model for ImageNet 64x64."""
  hparams = image_transformer_base()
  hparams.num_heads = 8
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.num_decoder_layers = 8
  hparams.batch_size = 1
  hparams.block_length = 512
  hparams.block_width = 768
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.max_length = 14000
  # int(False) == 0, i.e. conditional generation (stored as an int).
  hparams.unconditional = int(False)
  return hparams
@registry.register_hparams
def imagetransformer1d_base_12l_64by64():
  """Hparams for a 12-layer big 1D model for ImageNet 64x64."""
  hp = image_transformer_base()
  hp.num_heads = 8
  hp.hidden_size = 512
  hp.filter_size = 2048
  hp.num_decoder_layers = 12
  hp.batch_size = 1
  hp.block_length = 512
  hp.block_width = 768
  hp.layer_prepostprocess_dropout = 0.1
  hp.max_length = 14000
  # int(False) == 0: conditional generation, stored as an int.
  hp.unconditional = int(False)
  return hp
@registry.register_hparams
def imagetransformer_base_14l_8h_big():
  """imagetransformer_base_12l_8h_big with 14 decoder layers."""
  hp = imagetransformer_base_12l_8h_big()
  hp.num_decoder_layers = 14
  return hp
@registry.register_hparams
def imagetransformer_base_14l_8h_big_dr01():
  """imagetransformer_base_14l_8h_big with 0.1 pre/post-process dropout."""
  hp = imagetransformer_base_14l_8h_big()
  hp.layer_prepostprocess_dropout = 0.1
  return hp
@registry.register_hparams
def imagetransformer_base_12l_8h_big_uncond():
  """Big 12-layer 1D model for unconditional image generation."""
  hparams = imagetransformer_base_12l_8h_big()
  hparams.unconditional = True
  return hparams
@registry.register_hparams
def imagetransformer_base_14l_8h_big_uncond():
  """Big 14-layer 1D model for unconditional image generation."""
  hparams = imagetransformer_base_12l_8h_big_uncond()
  hparams.num_decoder_layers = 14
  return hparams
@registry.register_hparams
def imagetransformer_sep_channels_12l_16h_imagenet_large():
  """Separate RGB embeddings; 12 layers, 16 heads, filter 2048, lr 0.1."""
  hp = imagetransformer_sep_channels_8l_8h()
  hp.num_hidden_layers = 12
  hp.batch_size = 1
  hp.filter_size = 2048
  hp.num_heads = 16
  hp.learning_rate_warmup_steps = 16000
  hp.sampling_method = "random"
  hp.learning_rate = 0.1
  return hp
@registry.register_hparams
def imagetransformer_sep_channels_16l_16h_imgnet_lrg_loc():
  """Separate RGB embeddings; 16 layers, local attention, block length 256."""
  hp = imagetransformer_sep_channels_12l_16h_imagenet_large()
  hp.num_hidden_layers = 16
  hp.local_attention = True
  hp.batch_size = 1
  hp.block_length = 256
  return hp
@registry.register_hparams
def imagetransformer_sep_channels_16l_16h_imgnet_lrg_loc_128():
  """Separate RGB embeddings; 16 layers, local attention, block length 128."""
  hp = imagetransformer_sep_channels_12l_16h_imagenet_large()
  hp.num_hidden_layers = 16
  hp.local_attention = True
  hp.batch_size = 1
  hp.block_length = 128
  return hp
@registry.register_hparams
def imagetransformer_sep_output_channels_8l_local_and_global_att():
  """imagetransformer_sep_channels_8l with local_and_global_att enabled."""
  hp = imagetransformer_sep_channels_8l()
  hp.sampling_method = "random"
  hp.local_and_global_att = True
  return hp
@registry.register_hparams
def imagetransformer_base_10l_16h_big_uncond_dr01_imgnet():
  """Big 10-layer, 16-head 1D model (hidden 1024, filter 4096) for ImageNet.

  NOTE(review): despite "uncond" in the name, this function does not assign
  `hparams.unconditional`. It keeps whatever imagetransformer_base_14l_8h_big_dr01
  provides, and none of the hparams sets that chain builds on in this file
  assign it. The otherwise-identical imagetransformer_base_10l_16h_big_dr01_imgnet
  explicitly sets it to False. Confirm which mode is intended.
  """
  hparams = imagetransformer_base_14l_8h_big_dr01()
  hparams.num_decoder_layers = 10
  hparams.num_heads = 16
  hparams.hidden_size = 1024
  hparams.filter_size = 4096
  hparams.batch_size = 1
  hparams.layer_prepostprocess_dropout = 0.1
  return hparams
@registry.register_hparams
def imagetransformer_base_10l_16h_big_dr01_imgnet():
  """Big 10-layer, 16-head conditional 1D model for ImageNet."""
  hp = imagetransformer_base_14l_8h_big_dr01()
  hp.num_decoder_layers = 10
  hp.num_heads = 16
  hp.hidden_size = 1024
  hp.filter_size = 4096
  hp.batch_size = 1
  hp.unconditional = False
  hp.layer_prepostprocess_dropout = 0.1
  return hp
@registry.register_hparams
def imagetransformer_sep_channels_8l_8h():
  """Separate RGB embeddings; 8 layers, 8 heads, hidden/filter 512."""
  hp = imagetransformer_base()
  hp.num_heads = 8
  hp.batch_size = 1
  hp.attention_key_channels = hp.attention_value_channels = 0
  hp.hidden_size = 512
  hp.filter_size = 512
  hp.num_hidden_layers = 8
  hp.sampling_method = "random"
  return hp
@registry.register_hparams
def imagetransformer_sep_channels_8l_8h_local_and_global_att():
  """Small (4-layer, hidden 256) model with local_and_global_att enabled."""
  hp = imagetransformer_sep_channels_8l_8h()
  hp.num_heads = 8
  hp.batch_size = 1
  hp.attention_key_channels = hp.attention_value_channels = 0
  hp.hidden_size = 256
  hp.filter_size = 256
  hp.num_hidden_layers = 4
  hp.sampling_method = "random"
  hp.local_and_global_att = True
  return hp
@registry.register_hparams
def imagetransformer_bas8l_8h_big_uncond_dr03_imgnet():
  """Big 8-layer, 8-head 1D model (hidden 512, filter 2048, dropout 0.3).

  NOTE(review): despite "uncond" in the name, this function does not assign
  `hparams.unconditional`. It inherits the value from
  imagetransformer_base_14l_8h_big_dr01, and none of the hparams sets that
  chain builds on in this file assign it. Derived TPU configs that need
  unconditional generation set it themselves (e.g.
  imagetransformer_b12l_4h_b256_uncond_dr03_tpu).
  """
  hparams = imagetransformer_base_14l_8h_big_dr01()
  hparams.num_decoder_layers = 8
  hparams.num_heads = 8
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.layer_prepostprocess_dropout = 0.3
  return hparams
@registry.register_hparams
def imagetransformer_tiny():
  """Tiny unconditional model: 2 decoder layers, hidden 64, batch size 1."""
  hp = imagetransformer_base()
  hp.num_decoder_layers = 2
  hp.hidden_size = 64
  hp.batch_size = 1
  hp.unconditional = True
  hp.max_length = 66000  # Room for a flattened 256x256 image.
  return hp
@registry.register_hparams
def imagetransformerpp_tiny():
  """imagetransformer_tiny with DMOL likelihood and channel-compressed targets."""
  hp = imagetransformer_tiny()
  hp.likelihood = cia.DistributionType.DMOL
  hp.num_channels = 1
  hp.bottom["targets"] = modalities.image_channel_compress_targets_bottom
  hp.top["targets"] = modalities.identity_top
  return hp
@registry.register_hparams
def imagetransformer_tiny_tpu():
  """imagetransformer_tiny with TPU defaults, then shrunk further."""
  hp = imagetransformer_tiny()
  update_hparams_for_tpu(hp)
  # These overrides must come after the TPU defaults (which set batch_size=4).
  hp.num_hidden_layers = 2
  hp.hidden_size = 16
  hp.batch_size = 2
  hp.num_heads = 2
  return hp
@registry.register_hparams
def imagetransformer_base_10l_16h_big_dr01_moe_imgnet():
  """10-layer ImageNet model with MoE layers and scheduled sampling."""
  hp = imagetransformer_base_10l_16h_big_dr01_imgnet()
  hp.initializer = "orthogonal"
  hp.learning_rate_warmup_steps = 16000
  hp.add_hparam("moe_layers_decoder", "2,7")  # Which layers are MoE.
  hp.moe_hidden_sizes = "4096"  # Comma-separated hidden layer sizes.
  hp.moe_num_experts = 64  # Experts per MoE layer.
  hp.moe_k = 4  # Experts used per batch element (try 2 or 4).
  hp.moe_loss_coef = 3e-2  # MoE loss coefficient (1e-2 is usually ok).
  hp.scheduled_sampling_prob = 0.1
  hp.scheduled_sampling_warmup_steps = 200000
  return hp
@registry.register_hparams
def imagetransformer_moe_tiny():
  """Hyperparameters for a very small imagetransformer with MoE."""
  hp = imagetransformer_tiny()
  hp.hidden_size = 64
  hp.batch_size = 1
  hp.num_hidden_layers = 3
  hp.dec_attention_type = cia.AttentionType.MOE_LOCAL_1D
  hp.add_hparam("moe_layers_decoder", "1")  # Which layer is MoE.
  hp.moe_hidden_sizes = "1024"  # Comma-separated hidden layer sizes.
  hp.moe_num_experts = 16  # Experts per MoE layer.
  hp.moe_k = 2  # Experts used per batch element (try 2 or 4).
  hp.moe_loss_coef = 1e-2  # MoE loss coefficient (1e-2 is usually ok).
  return hp
def update_hparams_for_tpu(hparams):
  """Applies TPU-oriented defaults to `hparams` in place.

  Sets the Adafactor optimizer, an rsqrt_decay learning-rate schedule with
  6000 warmup steps, and a batch size of 4.

  Args:
    hparams: hparams object whose attributes are overwritten.

  Returns:
    None; `hparams` is mutated.
  """
  tpu_overrides = (
      ("optimizer", "Adafactor"),
      ("learning_rate_schedule", "rsqrt_decay"),
      ("learning_rate_warmup_steps", 6000),
      ("batch_size", 4),
  )
  for attr_name, attr_value in tpu_overrides:
    setattr(hparams, attr_name, attr_value)
@registry.register_hparams
def imagetransformer_sep_channels_8l_tpu():
  """Hparams for training imagetransformer_sep_channels_8l on TPU."""
  hp = imagetransformer_sep_channels_8l()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.shared_embedding_and_softmax_weights = False
  return hp
@registry.register_hparams
def imagetransformer_b10l_4h_big_uncond_dr03_tpu():
  """Small TPU model for CIFAR-10: 10 layers, 4 heads, block length 128."""
  hp = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.num_decoder_layers = 10
  hp.block_length = 128
  hp.hidden_size = 512
  hp.filter_size = 1024
  hp.learning_rate = 0.2
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  return hp
@registry.register_hparams
def imagetransformer_b10l_dr03_moe_tpu():
  """TPU model using the "local_moe_tpu" feed-forward layer."""
  hp = imagetransformer_b10l_4h_big_uncond_dr03_tpu()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.num_decoder_layers = 10
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.ffn_layer = "local_moe_tpu"
  return hp
@registry.register_hparams
def imagetransformer_b10l_4h_big_uncond_dr03_lr025_tpu():
  """Small TPU model with learning rate 0.25 and 8000 warmup steps."""
  hp = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.num_decoder_layers = 10
  hp.learning_rate = 0.25
  hp.learning_rate_warmup_steps = 8000
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  # `unconditional` is deliberately left as inherited (setting it to True was
  # tried and disabled here).
  return hp
@registry.register_hparams
def imagetransformer_b12l_4h_big_uncond_dr03_tpu():
  """12-layer, 4-head TPU model with block length 128 and dropout 0.3."""
  hp = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.num_decoder_layers = 12
  hp.block_length = 128
  hp.hidden_size = 512
  hp.filter_size = 1024
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.layer_prepostprocess_dropout = 0.3
  return hp
@registry.register_hparams
def imagetransformer_b12l_4h_big_uncond_dr03_lr025_tpu():
  """12-layer TPU model with learning rate 0.25 and 5000 warmup steps."""
  hp = imagetransformer_b12l_4h_big_uncond_dr03_tpu()
  update_hparams_for_tpu(hp)
  hp.learning_rate = 0.25
  hp.learning_rate_warmup_steps = 5000
  return hp
@registry.register_hparams
def imagetransformer_b12l_4h_b256_uncond_dr03_tpu():
  """Unconditional 12-layer TPU model, block length 256; works well on 4x4."""
  hp = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.num_decoder_layers = 12
  hp.block_length = 256
  hp.hidden_size = 512
  hp.filter_size = 2048
  hp.learning_rate = 0.5
  hp.learning_rate_warmup_steps = 4000
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.layer_prepostprocess_dropout = 0.3
  hp.unconditional = True
  return hp
@registry.register_hparams
def imagetransformer_b12l_4h_b256_uncond_dr03_rel_tpu():
  """b256 TPU model with shared relative local 1D attention; works on 4x4."""
  hp = imagetransformer_b12l_4h_b256_uncond_dr03_tpu()
  hp.shared_rel = True
  hp.dec_attention_type = cia.AttentionType.RELATIVE_LOCAL_1D
  return hp
@registry.register_ranged_hparams
def imagetransformer_cifar_tpu_range(rhp):
  """Vizier search ranges for the CIFAR TPU image transformer.

  Registers a log-scaled learning rate, discrete choices for depth, width and
  block length, and a categorical choice of decoder attention type.
  """
  attention_choices = [
      cia.AttentionType.RELATIVE_LOCAL_1D,
      cia.AttentionType.LOCAL_1D,
  ]
  rhp.set_float("learning_rate", 0.01, 1.0, scale=rhp.LOG_SCALE)
  rhp.set_discrete("num_decoder_layers", [8, 10, 12, 14, 16])
  rhp.set_discrete("hidden_size", [256, 512, 1024])
  rhp.set_discrete("block_length", [128, 256, 512])
  rhp.set_categorical("dec_attention_type", attention_choices)
@registry.register_hparams
def imagetransformer_b12l_4h_b128_h512_uncond_dr03_tpu():
  """12-layer TPU model: block length 128, hidden 512, filter 2048."""
  hp = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.num_decoder_layers = 12
  hp.block_length = 128
  hp.hidden_size = 512
  hp.filter_size = 2048
  hp.learning_rate = 0.2
  hp.learning_rate_warmup_steps = 6000
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.layer_prepostprocess_dropout = 0.3
  return hp
@registry.register_hparams
def imagetransformer_b12l_4h_b128_h512_uncond_dr01_im():
  """ImageNet TPU model: the b256 config with dropout 0.1.

  NOTE(review): the name says "b128", but block_length is inherited from
  imagetransformer_b12l_4h_b256_uncond_dr03_tpu (256) and is not overridden.
  """
  hp = imagetransformer_b12l_4h_b256_uncond_dr03_tpu()
  update_hparams_for_tpu(hp)
  hp.batch_size = 4
  hp.optimizer = "Adafactor"
  hp.learning_rate_schedule = "rsqrt_decay"
  hp.learning_rate_warmup_steps = 6000
  hp.layer_prepostprocess_dropout = 0.1
  return hp
@registry.register_hparams
def imagetransformer_b12l_4h_uncond_dr03_tpu():
  """b256 TPU model with learning rate 0.2 and 4000 warmup steps."""
  hp = imagetransformer_b12l_4h_b256_uncond_dr03_tpu()
  hp.learning_rate = 0.2
  hp.learning_rate_warmup_steps = 4000
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.layer_prepostprocess_dropout = 0.3
  return hp
@registry.register_hparams
def imagetransformer_b12l_4h_b128_uncond_dr03_tpu():
  """TPU config for CIFAR-10: 12 layers, hidden 256, block length 128."""
  hp = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(hp)
  hp.batch_size = 2
  hp.num_heads = 4  # Heads are expensive on TPU.
  hp.num_decoder_layers = 12
  hp.block_length = 128
  hp.hidden_size = 256
  hp.filter_size = 2048
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.layer_prepostprocess_dropout = 0.1
  hp.optimizer = "Adafactor"
  hp.learning_rate_schedule = "rsqrt_decay"
  hp.learning_rate_warmup_steps = 10000
  return hp
@registry.register_hparams
def imagetransformer_b12l_8h_b256_uncond_dr03_tpu():
  """12-layer, 8-head TPU model with block length 256."""
  hp = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
  update_hparams_for_tpu(hp)
  hp.batch_size = 2
  hp.num_heads = 8  # Heads are expensive on TPU; 8 here vs. 4 elsewhere.
  hp.num_decoder_layers = 12
  hp.block_length = 256
  hp.hidden_size = 512
  hp.filter_size = 2048
  hp.layer_preprocess_sequence = "none"
  hp.layer_postprocess_sequence = "dan"
  hp.layer_prepostprocess_dropout = 0.3
  return hp
@registry.register_hparams
def imagetransformer_b10l_4h_big_uncond_dr01_tpu():
  """Big 10-layer, 4-head TPU model (hidden 1024, filter 4096, dropout 0.1)."""
  hp = imagetransformer_b12l_4h_big_uncond_dr03_tpu()
  hp.num_decoder_layers = 10
  hp.num_heads = 4
  hp.hidden_size = 1024
  hp.filter_size = 4096
  hp.batch_size = 1
  hp.layer_prepostprocess_dropout = 0.1
  return hp
================================================
FILE: tensor2tensor/models/image_transformer_2d.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""image generation with transformer (attention).
encoder: [Self-Attention, Feed-forward] x n
decoder: [Self-Attention, Source-Target-Attention, Feed-forward] x n
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_image_attention as cia
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
@registry.register_model
class Imagetransformer2d(t2t_model.T2TModel):
  """Conditional image generation with attention. See file docstring."""

  def body(self, features):
    """Decodes targets with image attention, optionally class-conditioned.

    Args:
      features: dict of feature tensors. Must contain "targets". It must also
        contain "inputs" unless `hparams.unconditional` is set; "inputs" is
        only read on the conditional path.

    Returns:
      The result of `cia.create_output` for the decoded targets.
    """
    hparams = copy.copy(self._hparams)
    targets = features["targets"]
    if not (tf.get_variable_scope().reuse or
            hparams.mode == tf_estimator.ModeKeys.PREDICT):
      tf.summary.image("targets", targets, max_outputs=1)
    decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
    # Add the class label to the decoder input. Look up "inputs" only here so
    # that unconditional problems are not required to provide it.
    if not hparams.unconditional:
      inputs = features["inputs"]
      targets_shape = common_layers.shape_list(targets)
      decoder_input += tf.reshape(
          inputs, [targets_shape[0], 1, 1, hparams.hidden_size])
    decoder_output = cia.transformer_decoder_layers(
        decoder_input, None,
        hparams.num_decoder_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        name="decoder")
    output = cia.create_output(decoder_output, rows, cols, targets, hparams)
    return output
@registry.register_model
class Img2imgTransformer(t2t_model.T2TModel):
  """Image-to-image transformer: encodes an input image, decodes a target."""

  def body(self, features):
    """Runs the attention encoder on "inputs" and the decoder on "targets".

    Args:
      features: dict of feature tensors containing "inputs" and "targets".

    Returns:
      The result of `cia.create_output` for the decoded targets.
    """
    hp = copy.copy(self._hparams)
    tgt = features["targets"]
    src = features["inputs"]
    emit_summaries = (not tf.get_variable_scope().reuse and
                      hp.mode != tf_estimator.ModeKeys.PREDICT)
    if emit_summaries:
      tf.summary.image("inputs", src, max_outputs=1)
      tf.summary.image("targets", tgt, max_outputs=1)
    enc_out = cia.transformer_encoder_layers(
        cia.prepare_encoder(src, hp),
        hp.num_encoder_layers,
        hp,
        attention_type=hp.enc_attention_type,
        name="encoder")
    dec_in, rows, cols = cia.prepare_decoder(tgt, hp)
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        enc_out,
        hp.num_decoder_layers,
        hp,
        attention_type=hp.dec_attention_type,
        name="decoder")
    return cia.create_output(dec_out, rows, cols, tgt, hp)
@registry.register_model
class Img2imgTransformerBlockParallel(t2t_model.T2TModel):
"""Image-to-image transformer predicting blocks of the output in parallel."""
def body(self, features):
assert self._hparams.block_size > 0
assert not common_layers.is_xla_compiled()
hparams = copy.copy(self._hparams)
targets = features["targets"]
inputs = features["inputs"]
if not (tf.get_variable_scope().reuse or
hparams.mode == tf_estimator.ModeKeys.PREDICT):
tf.summary.image("inputs", inputs, max_outputs=1)
tf.summary.image("targets", targets, max_outputs=1)
encoder_input = cia.prepare_encoder(inputs, hparams)
encoder_output = cia.transformer_encoder_layers(
encoder_input,
hparams.num_encoder_layers,
hparams,
attention_type=hparams.enc_attention_type,
name="encoder")
decoder_input, rows, cols = cia.prepare_decoder(
targets, hparams)
decoder_output = cia.transformer_decoder_layers(
decoder_input,
encoder_output,
hparams.num_decoder_layers,
hparams,
attention_type=hparams.dec_attention_type,
name="decoder")
assert not isinstance(decoder_output, tuple)
assert len(decoder_output.shape) == 4
relu_dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(self._hparams, "relu_dropout_broadcast_dims", "")))
with tf.variable_scope("block_size_%d" % self._hparams.block_size):
tf.logging.info("Using block_size %d", self._hparams.block_size)
block_output = common_layers.dense_relu_dense(
decoder_output,
self._hparams.block_size * self._hparams.filter_size,
self._hparams.block_size * self._hparams.hidden_size,
dropout=self._hparams.relu_dropout,
dropout_broadcast_dims=relu_dropout_broadcast_dims)
batch_size, rows, cols = common_layers.shape_list(decoder_output)[:3]
decoder_output = tf.reshape(decoder_output, [
batch_size,
rows,
cols,
1,
self._hparams.hidden_size
])
block_output = tf.reshape(block_output, [
batch_size,
rows,
cols,
self._hparams.block_size,
self._hparams.hidden_size
])
block_output = common_layers.layer_postprocess(
decoder_output, block_output, self._hparams)
return block_output
def top(self, body_output, features):
assert self._hparams.block_size > 0
train_or_eval = (
self._hparams.mode == tf_estimator.ModeKeys.TRAIN or
self._hparams.mode == tf_estimator.ModeKeys.EVAL)
if train_or_eval:
if self._hparams.mode == tf_estimator.ModeKeys.TRAIN:
features["block_index"] = tf.random_uniform(
shape=[], minval=0, maxval=self._hparams.block_size, dtype=tf.int64)
else:
features["block_index"] = 0
body_output = body_output[:, :, :, features["block_index"], :]
decoded_image = tf.layers.dense(
body_output, 256, use_bias=True, activation=None, name="output_conv")
assert len(features["targets"].shape) == 4
targets_shape = common_layers.shape_list(features["targets"])
if train_or_eval:
output = tf.reshape(decoded_image, targets_shape + [256])
else:
output = tf.reshape(decoded_image, [
targets_shape[0], -1, self._hparams.block_size, 1, 256])
output = output[:, :targets_shape[1], :, :, :]
return output
def loss(self, logits, features):
assert self._hparams.block_size > 0
if self._hparams.mode == tf_estimator.ModeKeys.PREDICT:
return 0.0
def shift_left_2d(x, k):
return tf.pad(x, [[0, 0], [0, k]])[:, k:]
def shift_left_4d_raster_scan(x, k):
batch_size = common_layers.shape_list(x)[0]
return tf.reshape(
shift_left_2d(tf.reshape(x, [batch_size, -1]), k), tf.shape(x))
targets = features["targets"]
assert len(targets.shape) == 4
targets = tf.stack([
shift_left_4d_raster_scan(targets, i)
for i in range(self._hparams.block_size)
], axis=4)
if (self._hparams.mode == tf_estimator.ModeKeys.TRAIN or
self._hparams.mode == tf_estimator.ModeKeys.EVAL):
assert "block_index" in features
targets = targets[:, :, :, :, features["block_index"]]
features["targets"] = targets
loss = super(Img2imgTransformerBlockParallel, self).loss(logits, features)
if self._hparams.mode == tf_estimator.ModeKeys.TRAIN:
k = features["block_index"]
loss_num, loss_den = loss
loss_val = loss_num / loss_den
for i in range(self._hparams.block_size):
# Hack: if you report a loss of NaN, TensorBoard will plot a point at
# the previous value without a connecting line. This is used here to
# separate out the training losses by block index.
one_or_nan = tf.cond(tf.equal(k, i), lambda: 1.0, lambda: float("nan"))
tf.summary.scalar(
"block_index_%d" % i, one_or_nan * loss_val, family="losses")
return loss
def _greedy_infer(self, features, decode_length, use_tpu=False):
assert not use_tpu
return self._slow_greedy_infer_guess_and_check(features, decode_length)
def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
raise NotImplementedError
def _slow_greedy_infer_guess_and_check(self, features, decode_length):
  """Greedy block-parallel decoding using a guess-and-check loop.

  Each `tf.while_loop` iteration:
    1. pads the current decoded prefix by one position and runs the model
       once via `self.sample`;
    2. marks each already-decoded position as "correct" either when it is in
       the model's top-k (guess_and_check_top_k > 0) or when it is within
       guess_and_check_epsilon of the model's sample;
    3. keeps the longest all-correct prefix (capped at decode_length);
    4. appends `block_size` new guesses read from the samples at the first
       position after that prefix.

  Only one acceptance criterion (top-k or epsilon) may be enabled, batch size
  must be 1, and sampling must be argmax; these are enforced by asserts.

  Args:
    features: dict of features. Must contain "inputs" and must not contain
      "targets". "inputs" may be expanded to rank 4 during decoding and is
      restored before returning. "targets" is written on every step and is
      left in the dict afterwards.
    decode_length: number of positions to decode in addition to the input
      length (`tf.shape(features["inputs"])[1]` is added to it).

  Returns:
    dict with "outputs" (int64 tensor of shape [1, length, 1, 1]) and
    "scores" (None).
  """
  assert self._hparams.block_size > 0
  assert self._hparams.force_full_predict
  assert self._hparams.sampling_method == "argmax"
  assert self._decode_hparams.batch_size == 1
  assert self._decode_hparams.block_size > 0
  assert self._decode_hparams.block_size <= self._hparams.block_size
  # Exactly one acceptance criterion must be active: top-k or epsilon.
  assert (
      (self._decode_hparams.guess_and_check_top_k > 0) +
      (self._decode_hparams.guess_and_check_epsilon >= 0) == 1)
  # Saved so the caller's "inputs" can be restored after decoding.
  inputs_old = features["inputs"]
  assert "targets" not in features
  assert len(features["inputs"].shape) in [3, 4]
  if len(features["inputs"].shape) < 4:
    features["inputs"] = tf.expand_dims(features["inputs"], 2)
  block_size = self._decode_hparams.block_size
  # decode_length is given relative to the input length.
  decode_length += tf.shape(features["inputs"])[1]

  def while_exit_cond(result, length):  # pylint: disable=unused-argument
    return length < decode_length

  def infer_step(result, length):
    """Inference step."""

    # Host-side logging of how far the accepted prefix advanced; run through
    # tf.py_func below as a control dependency of the step's result.
    def print_info(samples, result, length, new_length):
      tf.logging.info(
          "length=%s new_length=%s length_diff=%s samples-result=%s",
          length,
          new_length,
          new_length - length,
          np.array_str(
              samples[0, -block_size-1:-1, 0, 0] -
              result[0, -block_size:, 0, 0]
          ).replace("\n", ""),
      )

    # Pad one extra position on the time axis so the model also predicts
    # the position after the current prefix.
    features["targets"] = tf.pad(result, [[0, 0], [0, 1], [0, 0], [0, 0]])
    samples, logits, losses = self.sample(features)  # pylint: disable=unused-variable
    # NOTE(review): axis 2 of logits/samples presumably holds the per-position
    # block of look-ahead predictions; only entry 0 is compared with `result`.
    _, top_k_indices = tf.nn.top_k(
        logits[:, :-1, :1, :, :],
        k=self._decode_hparams.guess_and_check_top_k)
    in_top_k = tf.reduce_any(
        tf.equal(tf.to_int64(top_k_indices), tf.expand_dims(result, 4)),
        axis=4)
    within_epsilon = tf.less_equal(
        tf.abs(result - samples[:, :-1, :1, :]),
        self._decode_hparams.guess_and_check_epsilon)
    # Both criteria are built in the graph; only the enabled one is used.
    if self._decode_hparams.guess_and_check_top_k:
      tf.logging.info(
          "Using guess_and_check_top_k=%s",
          self._decode_hparams.guess_and_check_top_k)
      correct = in_top_k
    else:
      tf.logging.info(
          "Using guess_and_check_epsilon=%s",
          self._decode_hparams.guess_and_check_epsilon)
      correct = within_epsilon
    # Position i equals perfect_cumsum (1..L, reshaped to [1, L, 1, 1]) only
    # if every position up to and including i is correct, so counting the
    # matches gives the length of the all-correct prefix.
    correct_cumsum = tf.cumsum(tf.to_int32(correct), axis=1)
    perfect_cumsum = 1 + tf.range(tf.shape(correct)[1])
    for axis in [0, 2, 3]:
      perfect_cumsum = tf.expand_dims(perfect_cumsum, axis=axis)
    new_length = tf.reduce_sum(
        tf.to_int32(tf.equal(correct_cumsum, perfect_cumsum)), axis=1)
    new_length = tf.squeeze(new_length, axis=[0, 1, 2])
    new_length = tf.minimum(new_length, decode_length)
    # Keep the accepted prefix, then append block_size fresh guesses taken
    # from the samples at the first position past that prefix.
    new_result = tf.concat([
        result[:, :new_length, :, :],
        tf.reshape(
            samples[:, new_length, :block_size, :], [1, block_size, 1, 1])
    ], axis=1)
    with tf.control_dependencies([
        tf.py_func(print_info, [samples, result, length, new_length], [])
    ]):
      new_result = tf.identity(new_result)
    return new_result, new_length

  # Start from an empty int64 prefix of shape [1, 0, 1, 1] and length 0.
  result = tf.zeros((1, 0, 1, 1), dtype=tf.int64)
  length = tf.squeeze(tf.zeros(1, dtype=tf.int32))
  result, length = tf.while_loop(
      while_exit_cond,
      infer_step,
      [result, length],
      shape_invariants=[
          tf.TensorShape([1, None, 1, 1]),
          tf.TensorShape([]),
      ],
      back_prop=False,
      parallel_iterations=1)
  # Drop the trailing, not-yet-verified guesses.
  result = result[:, :length, :, :]
  features["inputs"] = inputs_old
  return {
      "outputs": result,
      "scores": None,
  }
@registry.register_hparams
def image_transformer2d_base():
  """Base hyperparameters shared by the 2D image transformer variants.

  Starts from `common_hparams.basic_params1()`, overrides optimizer,
  learning-rate and modality settings, then registers the model-specific
  hyperparameters in a fixed order.

  Returns:
    The populated hparams object.
  """
  hparams = common_hparams.basic_params1()

  # Overrides of hyperparameters already defined by basic_params1().
  hparams.hidden_size = 512
  hparams.batch_size = 1
  hparams.max_length = 256
  hparams.dropout = 0.0
  hparams.clip_grad_norm = 0.0  # i.e. no gradient clipping
  hparams.optimizer_adam_epsilon = 1e-9
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.learning_rate_warmup_steps = 4000
  hparams.initializer_gain = 0.2
  hparams.initializer = "uniform_unit_scaling"
  hparams.weight_decay = 0.0
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.98
  hparams.label_smoothing = 0.0
  hparams.bottom["targets"] = modalities.make_targets_bottom(
      modalities.image_channel_embeddings_bottom)
  hparams.top["targets"] = modalities.identity_top
  hparams.norm_type = "layer"
  hparams.layer_prepostprocess_dropout = 0.0

  # New hyperparameters, registered in this order.
  new_hparams = (
      ("filter_size", 512),
      # attention-related flags
      ("num_heads", 8),
      ("attention_key_channels", 0),
      ("attention_value_channels", 0),
      ("ffn_layer", "conv_hidden_relu"),
      # All hyperparameters ending in "dropout" are automatically set to 0.0
      # when not in training mode.
      ("attention_dropout", 0.0),
      ("relu_dropout", 0.0),
      ("pos", "timing"),  # timing, none
      ("nbr_decoder_problems", 1),
      ("num_output_layers", 3),
      ("block_size", 1),
      # image size related flags; assumes images have equal height and width
      ("img_len", 32),
      ("num_channels", 3),
      # Local attention params
      ("local_and_global_att", False),
      ("block_length", 256),
      ("block_width", 128),
      # Local 2D attention params
      ("query_shape", (16, 16)),
      ("memory_flange", (16, 32)),
      ("num_encoder_layers", 4),
      ("num_decoder_layers", 8),
      # attention type related params
      ("enc_attention_type", cia.AttentionType.GLOBAL),
      ("dec_attention_type", cia.AttentionType.LOCAL_2D),
      ("block_raster_scan", False),
      # multipos attention params
      ("q_filter_width", 1),
      ("kv_filter_width", 1),
      ("unconditional", False),  # unconditional generation
      # relative embedding hparams
      ("shared_rel", False),
  )
  for hparam_name, default_value in new_hparams:
    hparams.add_hparam(hparam_name, default_value)
  return hparams
@registry.register_hparams
def imagetransformer2d_base():
  """image_transformer2d_base with block raster scan and local 2D decoding."""
  hparams = image_transformer2d_base()
  hparams.block_raster_scan = True
  hparams.dec_attention_type = cia.AttentionType.LOCAL_2D
  return hparams
@registry.register_hparams
def imagetransformer2d_base_8l_8_16():
  """8 decoder layers, batch size 1, memory flange (8, 16)."""
  hparams = image_transformer2d_base()
  hparams.memory_flange = (8, 16)
  hparams.batch_size = 1
  hparams.num_decoder_layers = 8
  return hparams
@registry.register_hparams
def imagetransformer2d_base_8l_8_16_ls():
  """8 decoder layers, memory flange (8, 16), label smoothing 0.05."""
  hparams = image_transformer2d_base()
  hparams.memory_flange = (8, 16)
  hparams.batch_size = 1
  hparams.num_decoder_layers = 8
  hparams.label_smoothing = 0.05
  return hparams
@registry.register_hparams
def imagetransformer2d_base_8l_8_16_big():
  """8 decoder layers, filter_size 1024, memory flange (8, 16)."""
  hparams = image_transformer2d_base()
  hparams.memory_flange = (8, 16)
  hparams.batch_size = 1
  hparams.num_decoder_layers = 8
  hparams.filter_size = 1024
  return hparams
@registry.register_hparams
def imagetransformer2d_base_12l_8_16_big():
  """12 decoder layers, filter_size 1024, random sampling with one beam."""
  hparams = image_transformer2d_base()
  hparams.memory_flange = (8, 16)
  hparams.batch_size = 1
  hparams.num_decoder_layers = 12
  hparams.filter_size = 1024
  # Decoding settings.
  hparams.beam_size = 1
  hparams.sampling_method = "random"
  return hparams
@registry.register_hparams
def imagetransformer2d_base_8l_8_32_big():
  """hparams for 8 layer big 2d model for cifar 10."""
  hparams = image_transformer2d_base()
  hparams.num_heads = 16
  hparams.hidden_size = 1024
  hparams.filter_size = 2048
  hparams.num_decoder_layers = 8
  hparams.batch_size = 1
  hparams.layer_prepostprocess_dropout = 0.3
  hparams.query_shape = (8, 16)
  hparams.memory_flange = (0, 32)
  # `unconditional` is registered as a bool in image_transformer2d_base; keep
  # that type instead of storing int(False) == 0.
  hparams.unconditional = False
  return hparams
@registry.register_hparams
def imagetransformer_base_10l_8h_big_uncond_dr03_dan_64_2d():
  """Big 2d model for unconditional generation on imagenet 64x64.

  NOTE(review): despite "10l" in the name, num_decoder_layers is not
  overridden here, so the image_transformer2d_base value (8) is used;
  confirm whether that is intended.
  """
  hparams = image_transformer2d_base()
  hparams.unconditional = True
  hparams.hidden_size = 512
  hparams.batch_size = 1
  hparams.img_len = 64
  hparams.num_heads = 8
  hparams.filter_size = 2048
  hparams.max_length = 14000
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.dec_attention_type = cia.AttentionType.LOCAL_2D
  hparams.query_shape = (16, 16)
  hparams.memory_flange = (8, 8)
  return hparams
@registry.register_hparams
def imagetransformer2d_base_8l_8_64_64by64():
  """hparams for 8 layer big 2d model for imagenet 64x64."""
  hparams = image_transformer2d_base()
  hparams.num_heads = 8
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.num_decoder_layers = 8
  hparams.batch_size = 1
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.query_shape = (8, 64)
  hparams.memory_flange = (4, 32)
  # `unconditional` is registered as a bool in image_transformer2d_base.
  hparams.unconditional = False
  hparams.max_length = 14000
  return hparams
@registry.register_hparams
def imagetransformer2d_base_12l_8_64_64by64():
  """hparams for 12 layer big 2d model for imagenet 64x64."""
  hparams = image_transformer2d_base()
  hparams.num_heads = 8
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.num_decoder_layers = 12
  hparams.batch_size = 1
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.query_shape = (8, 64)
  hparams.memory_flange = (4, 32)
  # `unconditional` is registered as a bool in image_transformer2d_base.
  hparams.unconditional = False
  hparams.max_length = 14000
  return hparams
@registry.register_hparams
def imagetransformer2d_base_14l_8_16_big():
  """14 decoder layers, filter_size 1024, memory flange (8, 16)."""
  hparams = image_transformer2d_base()
  hparams.memory_flange = (8, 16)
  hparams.batch_size = 1
  hparams.num_decoder_layers = 14
  hparams.filter_size = 1024
  return hparams
@registry.register_hparams
def imagetransformer2d_base_14l_8_16_big_uncond():
  """Unconditional variant of imagetransformer2d_base_14l_8_16_big."""
  base_hparams = imagetransformer2d_base_14l_8_16_big()
  base_hparams.unconditional = True
  return base_hparams
@registry.register_hparams
def imagetransformer2d_base_8l_8_16_big_16k():
  """8-layer big model with a 16000-step learning-rate warmup."""
  hparams = image_transformer2d_base()
  hparams.learning_rate_warmup_steps = 16000
  hparams.memory_flange = (8, 16)
  hparams.batch_size = 1
  hparams.num_decoder_layers = 8
  hparams.filter_size = 1024
  return hparams
@registry.register_hparams
def img2img_transformer2d_base():
  """Base params for img2img 2d attention."""
  hparams = image_transformer2d_base()
  # Layer pre/post-processing.
  hparams.layer_preprocess_sequence = "n"
  hparams.layer_postprocess_sequence = "da"
  hparams.layer_prepostprocess_dropout = 0.1
  # Optimization: this version seems to benefit from a higher learning rate.
  hparams.learning_rate = 0.2
  hparams.learning_rate_warmup_steps = 12000
  # Model size.
  hparams.filter_size = 2048
  hparams.num_encoder_layers = 4
  hparams.num_decoder_layers = 8
  # Input images are embedded with image_channel_embeddings_bottom.
  hparams.bottom["inputs"] = modalities.image_channel_embeddings_bottom
  # Local 2D decoder attention with block raster scan.
  hparams.dec_attention_type = cia.AttentionType.LOCAL_2D
  hparams.block_raster_scan = True
  return hparams
@registry.register_hparams
def img2img_transformer2d_q1():
  """img2img 2d, "dan" post-processing, query (16, 16), flange (16, 64)."""
  hparams = img2img_transformer2d_base()
  hparams.query_shape = (16, 16)
  hparams.memory_flange = (16, 64)
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.batch_size = 2
  return hparams
@registry.register_hparams
def img2img_transformer2d_q2():
  """img2img_transformer2d_q1 with a smaller (16, 32) memory flange."""
  hparams = img2img_transformer2d_q1()
  hparams.memory_flange = (16, 32)
  hparams.query_shape = (16, 16)
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.batch_size = 2
  return hparams
@registry.register_hparams
def img2img_transformer2d_q3():
  """Current best hparams for local 2d: query (8, 16), flange (8, 32)."""
  hparams = img2img_transformer2d_q1()
  hparams.query_shape = (8, 16)
  hparams.memory_flange = (8, 32)
  hparams.batch_size = 2
  return hparams
@registry.register_hparams
def img2img_transformer_base():
  """Base params for local1d attention."""
  hparams = image_transformer2d_base()
  # Layer pre/post-processing.
  hparams.layer_preprocess_sequence = "n"
  hparams.layer_postprocess_sequence = "da"
  hparams.layer_prepostprocess_dropout = 0.1
  # Optimization: this version seems to benefit from a higher learning rate.
  hparams.learning_rate = 0.2
  hparams.learning_rate_warmup_steps = 12000
  # Model size.
  hparams.filter_size = 2048
  hparams.num_encoder_layers = 4
  hparams.num_decoder_layers = 8
  # Local 1D decoder attention over 256x256 blocks, no raster scan.
  hparams.block_length = 256
  hparams.block_width = 256
  hparams.dec_attention_type = cia.AttentionType.LOCAL_1D
  hparams.block_raster_scan = False
  return hparams
@registry.register_hparams
def img2img_transformer_b1():
  """Local 1d img2img with "dan" post-processing and block_length 512."""
  hparams = img2img_transformer_base()
  hparams.block_length = 512
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.batch_size = 2
  return hparams
@registry.register_hparams
def img2img_transformer_b2():
  """Local 1d img2img with "dan" post-processing and block_length 256."""
  hparams = img2img_transformer_base()
  hparams.block_length = 256
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.batch_size = 2
  return hparams
@registry.register_hparams
def img2img_transformer_b3():
  """Current best hparams for local 1d.

  Uses "dan" post-processing, block_length 128 and sampling temperature 0.9.
  """
  hparams = img2img_transformer_base()
  hparams.block_length = 128
  hparams.sampling_temp = 0.9
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.batch_size = 2
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs1():
  """img2img_transformer_b3 with block_size 1."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 1
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs2():
  """img2img_transformer_b3 with block_size 2."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 2
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs3():
  """img2img_transformer_b3 with block_size 3."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 3
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs4():
  """img2img_transformer_b3 with block_size 4."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 4
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs5():
  """img2img_transformer_b3 with block_size 5."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 5
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs6():
  """img2img_transformer_b3 with block_size 6."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 6
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs7():
  """img2img_transformer_b3 with block_size 7."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 7
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs8():
  """img2img_transformer_b3 with block_size 8."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 8
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs9():
  """img2img_transformer_b3 with block_size 9."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 9
  return hparams
@registry.register_hparams
def img2img_transformer_b3_bs10():
  """img2img_transformer_b3 with block_size 10."""
  hparams = img2img_transformer_b3()
  hparams.block_size = 10
  return hparams
@registry.register_hparams
def img2img_transformer_dilated():
  """Try dilated attention on top of img2img_transformer_base."""
  hparams = img2img_transformer_base()
  hparams.add_hparam("num_memory_blocks", 1)
  # image_transformer2d_base does not register gap_sizes. Register it with
  # add_hparam, like num_memory_blocks above, so that it is tracked like the
  # other hyperparameters instead of being a plain attribute.
  hparams.add_hparam("gap_sizes", [0, 16, 64, 0, 16, 64, 128, 0])
  hparams.num_heads = 8
  hparams.attention_key_channels = hparams.attention_value_channels = 0
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.num_decoder_layers = 8
  hparams.sampling_method = "random"
  hparams.dec_attention_type = cia.AttentionType.DILATED
  hparams.img_len = 64
  hparams.block_length = 128
  hparams.block_width = 128
  return hparams
@registry.register_hparams
def imagetransformer2d_tiny():
  """Tiny imagetransformer2d_base: 2 decoder layers, hidden size 64."""
  hparams = imagetransformer2d_base()
  hparams.hidden_size = 64
  hparams.num_decoder_layers = 2
  hparams.batch_size = 1
  return hparams
def update_hparams_for_tpu(hparams):
  """Apply TPU-friendly settings to `hparams` in place.

  Args:
    hparams: hyperparameter object; modified in place. Returns None.
  """
  tpu_settings = (
      ("use_pad_remover", False),  # where op not supported
      ("optimizer", "true_adam"),
      ("batch_size", 4),
  )
  for attr_name, attr_value in tpu_settings:
    setattr(hparams, attr_name, attr_value)
@registry.register_hparams
def img2img_transformer_base_tpu():
  """Hparams for training img2img_transformer on tpu."""
  hparams = img2img_transformer_base()
  update_hparams_for_tpu(hparams)
  # Overrides the batch size chosen by update_hparams_for_tpu.
  hparams.batch_size = 2
  hparams.num_heads = 4  # heads are expensive on tpu
  hparams.num_encoder_layers = 4
  hparams.num_decoder_layers = 8
  hparams.shared_embedding_and_softmax_weights = False
  return hparams
@registry.register_hparams
def img2img_transformer_tiny_tpu():
  """Tiny img2img_transformer hparams for tpu.

  The image transformer hparams carry their decoder depth in
  num_decoder_layers (registered by image_transformer2d_base, set to 8 by
  img2img_transformer_base_tpu). Setting only num_hidden_layers therefore
  leaves that depth at 8. Reduce num_decoder_layers too, as the other *_tiny
  sets do. num_hidden_layers is kept for compatibility.
  """
  hparams = img2img_transformer_base_tpu()
  hparams.num_hidden_layers = 2
  hparams.num_decoder_layers = 2
  hparams.hidden_size = 16
  hparams.batch_size = 2
  hparams.num_heads = 2
  return hparams
@registry.register_hparams
def img2img_transformer2d_n3():
  """img2img 2d: 4/12 enc/dec layers, query (16, 32), no pre/post dropout."""
  hparams = img2img_transformer2d_base()
  hparams.num_encoder_layers = 4
  hparams.num_decoder_layers = 12
  hparams.query_shape = (16, 32)
  hparams.memory_flange = (16, 16)
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.batch_size = 1
  return hparams
@registry.register_hparams
def img2img_transformer2d_n31():
  """img2img 2d: 6/12 enc/dec layers, 8 heads, query and flange (16, 32)."""
  hparams = img2img_transformer2d_base()
  hparams.num_encoder_layers = 6
  hparams.num_decoder_layers = 12
  hparams.num_heads = 8
  hparams.query_shape = (16, 32)
  hparams.memory_flange = (16, 32)
  hparams.batch_size = 1
  return hparams
@registry.register_hparams
def img2img_transformer2d_n24():
  """img2img 2d with hidden size 1024 and pre/post dropout 0.2."""
  hparams = img2img_transformer2d_base()
  hparams.hidden_size = 1024
  hparams.filter_size = 2048
  hparams.num_decoder_layers = 8
  hparams.layer_prepostprocess_dropout = 0.2
  hparams.query_shape = (8, 16)
  hparams.memory_flange = (8, 32)
  hparams.batch_size = 1
  return hparams
@registry.register_hparams
def img2img_transformer2d_n44():
  """img2img 2d: 8 decoder layers, query (8, 16), memory flange (8, 32)."""
  hparams = img2img_transformer2d_base()
  hparams.num_decoder_layers = 8
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.query_shape = (8, 16)
  hparams.memory_flange = (8, 32)
  hparams.batch_size = 1
  return hparams
@registry.register_hparams
def img2img_transformer2d_n103():
  """Best config for img2img.

  6 encoder / 12 decoder layers, query (8, 32), memory flange (8, 64).
  """
  hparams = img2img_transformer2d_base()
  hparams.num_encoder_layers = 6
  hparams.num_decoder_layers = 12
  hparams.query_shape = (8, 32)
  hparams.memory_flange = (8, 64)
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.batch_size = 1
  return hparams
@registry.register_hparams
def img2img_transformer2d_tiny():
  """Tiny params for the img2img 2d transformer."""
  hparams = img2img_transformer2d_base()
  # Model size.
  hparams.num_decoder_layers = 2
  hparams.hidden_size = 128
  hparams.filter_size = 128
  hparams.num_heads = 4
  hparams.attention_key_channels = hparams.attention_value_channels = 0
  # Data / positional settings.
  hparams.batch_size = 4
  hparams.max_length = 128
  hparams.pos = "timing"
  hparams.img_len = 32
  return hparams
@registry.register_hparams
def img2img_transformer_tiny():
  """Tiny params.

  Decoder depth for these hparams lives in num_decoder_layers (registered by
  image_transformer2d_base, set to 8 by img2img_transformer2d_base). Setting
  only num_hidden_layers leaves it at 8, so reduce num_decoder_layers as
  img2img_transformer2d_tiny does. num_hidden_layers is kept for
  compatibility.
  """
  hparams = img2img_transformer2d_base()
  hparams.num_hidden_layers = 2
  hparams.num_decoder_layers = 2
  hparams.hidden_size = 128
  hparams.batch_size = 4
  hparams.max_length = 128
  hparams.attention_key_channels = hparams.attention_value_channels = 0
  hparams.filter_size = 128
  hparams.num_heads = 1
  hparams.pos = "timing"
  return hparams
================================================
FILE: tensor2tensor/models/image_transformer_2d_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.data_generators import celeba # pylint: disable=unused-import
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import image_transformer_2d
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class Img2imgTransformerTest(tf.test.TestCase):
  """Smoke test for Img2imgTransformer: builds and checks logits shape."""

  def _test_img2img_transformer(self, net):
    """Runs `net` once in TRAIN mode on random 4x4 -> 8x8 RGB images.

    Args:
      net: model class, constructed as net(hparams, mode, problem_hparams).
    """
    batch_size = 3
    hparams = image_transformer_2d.img2img_transformer2d_tiny()
    hparams.data_dir = ""
    p_hparams = registry.problem("image_celeba").get_hparams(hparams)
    inputs = np.random.randint(256, size=(batch_size, 4, 4, 3))
    targets = np.random.randint(256, size=(batch_size, 8, 8, 3))
    with self.test_session() as session:
      raw_features = (("inputs", inputs),
                      ("targets", targets),
                      ("target_space_id", 1))
      features = {key: tf.constant(value, dtype=tf.int32)
                  for key, value in raw_features}
      model = net(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    # 256 logits for every channel of every 8x8 target pixel.
    self.assertEqual(res.shape, (batch_size, 8, 8, 3, 256))

  def testImg2imgTransformer(self):
    self._test_img2img_transformer(image_transformer_2d.Img2imgTransformer)
class Imagetransformer2dTest(tf.test.TestCase):
  """Smoke test for Imagetransformer2d: builds and checks logits shape."""

  def _test_imagetransformer_2d(self, net):
    """Runs `net` once in TRAIN mode on random 7x7 RGB targets.

    Args:
      net: model class, constructed as net(hparams, mode, problem_hparams).
    """
    batch_size, size, vocab_size = 3, 7, 256
    hparams = image_transformer_2d.imagetransformer2d_tiny()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    inputs = np.random.randint(vocab_size, size=(batch_size, 1, 1, 1))
    targets = np.random.randint(vocab_size, size=(batch_size, size, size, 3))
    with self.test_session() as session:
      raw_features = (("inputs", inputs),
                      ("targets", targets),
                      ("target_space_id", 1))
      features = {key: tf.constant(value, dtype=tf.int32)
                  for key, value in raw_features}
      model = net(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, size, size, 3, vocab_size))

  def testImagetransformer2d(self):
    self._test_imagetransformer_2d(image_transformer_2d.Imagetransformer2d)
if __name__ == "__main__":
  # Run the test cases defined in this module.
  tf.test.main()
================================================
FILE: tensor2tensor/models/image_transformer_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.layers import common_image_attention
from tensor2tensor.models import image_transformer
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class ImagetransformerTest(parameterized.TestCase, tf.test.TestCase):
  """Logits-shape checks for Imagetransformer with CAT and mixture outputs."""

  @parameterized.named_parameters(
      ("ImageTransformerCat",
       image_transformer.Imagetransformer,
       image_transformer.imagetransformer_tiny()),
      ("ImageTransformerDmol",
       image_transformer.Imagetransformer,
       image_transformer.imagetransformerpp_tiny()),
  )
  def testImagetransformer(self, net, hparams):
    """Builds `net` in TRAIN mode on random 7x7 RGB targets; checks shape."""
    batch_size, size, vocab_size = 3, 7, 256
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    inputs = np.random.randint(vocab_size, size=(batch_size, 1, 1, 1))
    targets = np.random.randint(vocab_size, size=(batch_size, size, size, 3))
    with self.test_session() as session:
      raw_features = (("inputs", inputs),
                      ("targets", targets),
                      ("target_space_id", 1))
      features = {key: tf.constant(value, dtype=tf.int32)
                  for key, value in raw_features}
      model = net(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    is_categorical = (
        hparams.likelihood == common_image_attention.DistributionType.CAT)
    # Categorical: one vocab-sized distribution per channel. Otherwise the
    # last axis holds 10 parameters per mixture component.
    expected = ((batch_size, size, size, 3, vocab_size) if is_categorical
                else (batch_size, size, size, hparams.num_mixtures * 10))
    self.assertEqual(res.shape, expected)
if __name__ == "__main__":
  # Run the parameterized test cases defined in this module.
  tf.test.main()
================================================
FILE: tensor2tensor/models/lstm.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RNN LSTM models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from tensor2tensor.layers import area_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import contrib
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
def _dropout_lstm_cell(hparams, train):
  """Build an LSTMCell of width `hparams.hidden_size` with input dropout.

  Args:
    hparams: HParams; reads `hidden_size` and `dropout`.
    train: bool; when truthy, inputs are kept with probability
      `1 - hparams.dropout`; otherwise the keep probability is 1.

  Returns:
    A `tf.nn.rnn_cell.DropoutWrapper` wrapping a new `LSTMCell`.
  """
  lstm_cell = tf.nn.rnn_cell.LSTMCell(hparams.hidden_size)
  keep_prob = 1.0 - hparams.dropout * tf.to_float(train)
  return tf.nn.rnn_cell.DropoutWrapper(lstm_cell, input_keep_prob=keep_prob)
def lstm(inputs, sequence_length, hparams, train, name, initial_state=None):
  """Adds a stack of unidirectional LSTM layers on top of input.

  Args:
    inputs: The input `Tensor`, shaped `[batch_size, time_steps, hidden_size]`.
    sequence_length: Lengths of the actual input sequence, excluding padding; a
      `Tensor` shaped `[batch_size]`.
    hparams: HParams; hyperparameters. Reads `num_hidden_layers`, plus
      `hidden_size` and `dropout` for each layer's cell.
    train: bool; `True` when constructing training graph to enable dropout.
    name: string; Create variable names under this scope.
    initial_state: tuple of `LSTMStateTuple`s; the initial state of each layer.

  Returns:
    A tuple (outputs, states), where:
      outputs: The output `Tensor`, shaped `[batch_size, time_steps,
        hidden_size]`.
      states: A tuple of `LSTMStateTuple`s; the final state of each layer.
  """
  cells = [_dropout_lstm_cell(hparams, train)
           for _ in range(hparams.num_hidden_layers)]
  with tf.variable_scope(name):
    return tf.nn.dynamic_rnn(
        tf.nn.rnn_cell.MultiRNNCell(cells),
        inputs,
        sequence_length,
        initial_state=initial_state,
        dtype=tf.float32,
        time_major=False)
def lstm_attention_decoder(inputs, hparams, train, name, initial_state,
                           encoder_outputs, encoder_output_length,
                           decoder_input_length):
  """Run LSTM cell with attention on inputs of shape [batch x time x size].

  Args:
    inputs: The decoder input `Tensor`, shaped `[batch_size, decoder_steps,
      hidden_size]`.
    hparams: HParams; hyperparameters.
    train: bool; `True` when constructing training graph to enable dropout.
    name: string; Create variable names under this scope.
    initial_state: Tuple of `LSTMStateTuple`s; the initial state of each layer.
    encoder_outputs: Encoder outputs; a `Tensor` shaped `[batch_size,
      encoder_steps, hidden_size]`.
    encoder_output_length: Lengths of the actual encoder outputs, excluding
      padding; a `Tensor` shaped `[batch_size]`. Only used when area attention
      is enabled (max_area_width > 1).
    decoder_input_length: Lengths of the actual decoder inputs, excluding
      padding; a `Tensor` shaped `[batch_size]`.

  Raises:
    ValueError: If the hparams.attention_mechanism is anything other than
      luong or bahdanau. With area attention enabled, an unsupported
      area_value_mode also raises ValueError, but only when the attention
      keys/values are computed.

  Returns:
    The decoder output `Tensor`, shaped `[batch_size, decoder_steps,
    hidden_size]`.
  """
  layers = [_dropout_lstm_cell(hparams, train)
            for _ in range(hparams.num_hidden_layers)]
  if hparams.attention_mechanism == "luong":
    attention_mechanism_class = contrib.seq2seq().LuongAttention
  elif hparams.attention_mechanism == "bahdanau":
    attention_mechanism_class = contrib.seq2seq().BahdanauAttention
  else:
    raise ValueError("Unknown hparams.attention_mechanism = %s, must be "
                     "luong or bahdanau." % hparams.attention_mechanism)
  # Area attention is enabled only when max_area_width > 1. Inside that
  # branch the key is present, so every lookup below sees the same value.
  max_area_width = hparams.get("max_area_width", 1)
  if max_area_width > 1:
    area_key_mode = hparams.get("area_key_mode", "none")
    area_value_mode = hparams.get("area_value_mode", "none")

    def _area_key_value_fn(keys, values):
      """Custom fn for computing area keys and values."""
      tf.logging.info("max_area_width=%d, area_key_mode=%s, area_value_mode=%s",
                      max_area_width, area_key_mode, area_value_mode)
      keys = area_attention.compute_area_key(
          keys, max_area_width=max_area_width,
          mode=area_key_mode, name="decoder_encoder",
          training=(hparams.mode == tf_estimator.ModeKeys.TRAIN))
      if area_value_mode == "sum":
        _, _, values, _, _ = area_attention.compute_area_features(
            values, max_area_width=max_area_width)
      elif area_value_mode == "mean":
        values, _, _, _, _ = area_attention.compute_area_features(
            values, max_area_width=max_area_width)
      else:
        raise ValueError(
            "Unsupported area_value_mode: %s" % area_value_mode)
      return keys, values

    # Use an int max_area_size, matching the value used everywhere else.
    area_mask = area_attention.lengths_to_area_mask(
        feature_length=encoder_output_length,
        length=common_layers.shape_list(encoder_outputs)[1],
        max_area_size=max_area_width)

    def _area_prob_fn(score):
      # Softmax, zero out masked areas, then renormalize.
      alignments = tf.nn.softmax(score)
      alignments = tf.where(area_mask, alignments, tf.zeros_like(alignments))
      alignments = tf.div(alignments, tf.reduce_sum(
          alignments, axis=-1, keepdims=True))
      return alignments

    attention_mechanism = attention_mechanism_class(
        hparams.hidden_size, encoder_outputs,
        memory_sequence_length=None,
        probability_fn=_area_prob_fn,
        custom_key_value_fn=_area_key_value_fn)
  else:
    # NOTE(review): encoder_output_length is not passed as
    # memory_sequence_length here, so attention is not masked over encoder
    # padding in this branch; confirm whether that is intended.
    attention_mechanism = attention_mechanism_class(hparams.hidden_size,
                                                    encoder_outputs)
  cell = contrib.seq2seq().AttentionWrapper(
      tf.nn.rnn_cell.MultiRNNCell(layers),
      [attention_mechanism] * hparams.num_heads,
      attention_layer_size=[hparams.attention_layer_size] * hparams.num_heads,
      output_attention=(hparams.output_attention == 1))
  batch_size = common_layers.shape_list(inputs)[0]
  initial_state = cell.zero_state(batch_size, tf.float32).clone(
      cell_state=initial_state)
  with tf.variable_scope(name):
    output, _ = tf.nn.dynamic_rnn(
        cell,
        inputs,
        decoder_input_length,
        initial_state=initial_state,
        dtype=tf.float32,
        time_major=False)
    # output is [batch_size, decoder_steps, attention_size], where
    # attention_size is either hparams.hidden_size (when
    # hparams.output_attention is 0) or hparams.attention_layer_size (when
    # hparams.output_attention is 1) times the number of attention heads.
    #
    # For multi-head attention project output back to hidden size.
    if hparams.output_attention == 1 and hparams.num_heads > 1:
      output = tf.layers.dense(output, hparams.hidden_size)
    return output
def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training.

  Args:
    inputs: embedded inputs (4-D), or None for decoder-only use.
    targets: embedded targets (4-D).
    hparams: HParams; hyperparameters.
    train: bool; enables dropout when True.

  Returns:
    Decoder outputs with a singleton dimension inserted at axis 2.
  """
  with tf.variable_scope("lstm_seq2seq"):
    final_encoder_state = None
    if inputs is not None:
      inputs_length = common_layers.length_from_embedding(inputs)
      # Encoder: flatten to 3-D, reverse each sequence within its length.
      reversed_inputs = tf.reverse_sequence(
          common_layers.flatten4d3d(inputs), inputs_length, seq_axis=1)
      _, final_encoder_state = lstm(
          reversed_inputs, inputs_length, hparams, train, "encoder")
    # Decoder on right-shifted targets; shift_right pads one position on the
    # left, so each length grows by one.
    shifted_targets = common_layers.shift_right(targets)
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        targets_length,
        hparams,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)
def lstm_seq2seq_internal_attention(inputs, targets, hparams, train,
                                    inputs_length, targets_length):
  """LSTM seq2seq model with attention, main step used for training.

  Args:
    inputs: embedded inputs (4-D).
    targets: embedded targets (4-D).
    hparams: HParams; hyperparameters.
    train: bool; enables dropout when True.
    inputs_length: per-example input lengths.
    targets_length: per-example target lengths (before right-shifting).

  Returns:
    Decoder outputs with a singleton dimension inserted at axis 2.
  """
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Encoder: flatten to 3-D, reverse each sequence within its length.
    reversed_inputs = tf.reverse_sequence(
        common_layers.flatten4d3d(inputs), inputs_length, seq_axis=1)
    encoder_outputs, final_encoder_state = lstm(
        reversed_inputs, inputs_length, hparams, train, "encoder")
    # Decoder with attention; shift_right pads one position on the left.
    shifted_targets = common_layers.shift_right(targets)
    shifted_length = targets_length + 1
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs, inputs_length, shifted_length)
    return tf.expand_dims(decoder_outputs, axis=2)
def lstm_bid_encoder(inputs, sequence_length, hparams, train, name):
  """Bidirectional LSTM for encoding inputs that are [batch x time x size].

  Args:
    inputs: input `Tensor`, shaped [batch, time, size].
    sequence_length: per-example lengths, excluding padding.
    hparams: HParams; reads `num_hidden_layers` (and the cell settings used
      by `_dropout_lstm_cell`).
    train: bool; enables dropout when True.
    name: variable scope name.

  Returns:
    (encoder_outputs, encoder_states): the forward and backward outputs
    concatenated on axis 2, and a tuple with one state per layer in which the
    forward and backward states are concatenated on axis 1.

  Raises:
    TypeError: if a layer's state is neither an `LSTMStateTuple` nor a
      `Tensor`.
  """
  with tf.variable_scope(name):
    cell_fw = tf.nn.rnn_cell.MultiRNNCell(
        [_dropout_lstm_cell(hparams, train)
         for _ in range(hparams.num_hidden_layers)])
    cell_bw = tf.nn.rnn_cell.MultiRNNCell(
        [_dropout_lstm_cell(hparams, train)
         for _ in range(hparams.num_hidden_layers)])
    ((encoder_fw_outputs, encoder_bw_outputs),
     (encoder_fw_state, encoder_bw_state)) = tf.nn.bidirectional_dynamic_rnn(
         cell_fw,
         cell_bw,
         inputs,
         sequence_length,
         dtype=tf.float32,
         time_major=False)
    encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
    encoder_states = []
    for i in range(hparams.num_hidden_layers):
      fw_state, bw_state = encoder_fw_state[i], encoder_bw_state[i]
      if isinstance(fw_state, tf.nn.rnn_cell.LSTMStateTuple):
        encoder_state_c = tf.concat(
            values=(fw_state.c, bw_state.c),
            axis=1,
            name="encoder_fw_state_c")
        encoder_state_h = tf.concat(
            values=(fw_state.h, bw_state.h),
            axis=1,
            name="encoder_fw_state_h")
        encoder_state = tf.nn.rnn_cell.LSTMStateTuple(
            c=encoder_state_c, h=encoder_state_h)
      elif isinstance(fw_state, tf.Tensor):
        encoder_state = tf.concat(
            values=(fw_state, bw_state),
            axis=1,
            name="bidirectional_concat")
      else:
        # Fail loudly instead of appending a stale state from the previous
        # layer (or hitting an unbound local on the first layer).
        raise TypeError(
            "Unsupported encoder state type for layer %d: %s" %
            (i, type(fw_state).__name__))
      encoder_states.append(encoder_state)
    encoder_states = tuple(encoder_states)
    return encoder_outputs, encoder_states
def lstm_seq2seq_internal_bid_encoder(inputs, targets, hparams, train):
"""The basic LSTM seq2seq model with bidirectional encoder."""
with tf.variable_scope("lstm_seq2seq_bid_encoder"):
if inputs is not None:
inputs_length = common_layers.length_from_embedding(inputs)
# Flatten inputs.
inputs = common_layers.flatten4d3d(inputs)
# LSTM encoder.
_, final_encoder_state = lstm_bid_encoder(
inputs, inputs_length, hparams, train, "encoder")
else:
inputs_length = None
final_encoder_state = None
# LSTM decoder.
shifted_targets = common_layers.shift_right(targets)
# Add 1 to account for the padding added to the left from shift_right
targets_length = common_layers.length_from_embedding(shifted_targets) + 1
hparams_decoder = copy.copy(hparams)
hparams_decoder.hidden_size = 2 * hparams.hidden_size
decoder_outputs, _ = lstm(
common_layers.flatten4d3d(shifted_targets),
targets_length,
hparams_decoder,
train,
"decoder",
initial_state=final_encoder_state)
return tf.expand_dims(decoder_outputs, axis=2)
def lstm_seq2seq_internal_attention_bid_encoder(inputs, targets, hparams,
train):
"""LSTM seq2seq model with attention, main step used for training."""
with tf.variable_scope("lstm_seq2seq_attention_bid_encoder"):
inputs_length = common_layers.length_from_embedding(inputs)
# Flatten inputs.
inputs = common_layers.flatten4d3d(inputs)
# LSTM encoder.
encoder_outputs, final_encoder_state = lstm_bid_encoder(
inputs, inputs_length, hparams, train, "encoder")
# LSTM decoder with attention
shifted_targets = common_layers.shift_right(targets)
# Add 1 to account for the padding added to the left from shift_right
targets_length = common_layers.length_from_embedding(shifted_targets) + 1
hparams_decoder = copy.copy(hparams)
hparams_decoder.hidden_size = 2 * hparams.hidden_size
decoder_outputs = lstm_attention_decoder(
common_layers.flatten4d3d(shifted_targets), hparams_decoder, train,
"decoder", final_encoder_state, encoder_outputs,
inputs_length, targets_length)
return tf.expand_dims(decoder_outputs, axis=2)
@registry.register_model
class LSTMEncoder(t2t_model.T2TModel):
"""LSTM encoder only."""
def body(self, features):
if self._hparams.initializer == "orthogonal":
raise ValueError("LSTM models fail with orthogonal initializer.")
train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN
inputs = features.get("inputs")
inputs_length = common_layers.length_from_embedding(inputs)
# Flatten inputs.
inputs = common_layers.flatten4d3d(inputs)
# LSTM encoder.
inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
encoder_output, _ = lstm(inputs, inputs_length, self._hparams, train,
"encoder")
return tf.expand_dims(encoder_output, axis=2)
@registry.register_model
class LSTMSeq2seq(t2t_model.T2TModel):
def body(self, features):
# TODO(lukaszkaiser): investigate this issue and repair.
if self._hparams.initializer == "orthogonal":
raise ValueError("LSTM models fail with orthogonal initializer.")
train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN
return lstm_seq2seq_internal(features.get("inputs"), features["targets"],
self._hparams, train)
@registry.register_model
class LSTMSeq2seqAttention(t2t_model.T2TModel):
"""Seq to seq LSTM with attention."""
def body(self, features):
# TODO(lukaszkaiser): investigate this issue and repair.
if self._hparams.initializer == "orthogonal":
raise ValueError("LSTM models fail with orthogonal initializer.")
train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN
# This is a temporary fix for varying-length sequences within in a batch.
# A more complete fix should pass a length tensor from outside so that
# all the lstm variants can use it.
input_shape = common_layers.shape_list(features["inputs_raw"])
flat_input = tf.reshape(features["inputs_raw"],
[input_shape[0], input_shape[1]])
inputs_length = tf.reduce_sum(tf.minimum(flat_input, 1), -1)
target_shape = common_layers.shape_list(features["targets_raw"])
flat_target = tf.reshape(features["targets_raw"],
[target_shape[0], target_shape[1]])
targets_length = tf.reduce_sum(tf.minimum(flat_target, 1), -1)
tf.logging.info(self._hparams)
return lstm_seq2seq_internal_attention(
features["inputs"], features["targets"], self._hparams, train,
inputs_length, targets_length)
@registry.register_model
class LSTMSeq2seqBidirectionalEncoder(t2t_model.T2TModel):
def body(self, features):
# TODO(lukaszkaiser): investigate this issue and repair.
if self._hparams.initializer == "orthogonal":
raise ValueError("LSTM models fail with orthogonal initializer.")
train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN
return lstm_seq2seq_internal_bid_encoder(
features.get("inputs"), features["targets"], self._hparams, train)
@registry.register_model
class LSTMSeq2seqAttentionBidirectionalEncoder(t2t_model.T2TModel):
def body(self, features):
# TODO(lukaszkaiser): investigate this issue and repair.
if self._hparams.initializer == "orthogonal":
raise ValueError("LSTM models fail with orthogonal initializer.")
train = self._hparams.mode == tf_estimator.ModeKeys.TRAIN
return lstm_seq2seq_internal_attention_bid_encoder(
features.get("inputs"), features["targets"], self._hparams, train)
@registry.register_hparams
def lstm_seq2seq():
"""hparams for LSTM."""
hparams = common_hparams.basic_params1()
hparams.daisy_chain_variables = False
hparams.batch_size = 1024
hparams.hidden_size = 128
hparams.num_hidden_layers = 2
hparams.initializer = "uniform_unit_scaling"
hparams.initializer_gain = 1.0
hparams.weight_decay = 0.0
return hparams
def lstm_attention_base():
"""Base attention params."""
hparams = lstm_seq2seq()
hparams.add_hparam("attention_layer_size", hparams.hidden_size)
hparams.add_hparam("output_attention", True)
hparams.add_hparam("num_heads", 1)
return hparams
@registry.register_hparams
def lstm_bahdanau_attention():
"""Hparams for LSTM with bahdanau attention."""
hparams = lstm_attention_base()
hparams.add_hparam("attention_mechanism", "bahdanau")
return hparams
@registry.register_hparams
def lstm_luong_attention():
"""Hparams for LSTM with luong attention."""
hparams = lstm_attention_base()
hparams.add_hparam("attention_mechanism", "luong")
return hparams
@registry.register_hparams
def lstm_attention():
"""For backwards compatibility, defaults to bahdanau."""
return lstm_bahdanau_attention()
@registry.register_hparams
def lstm_bahdanau_attention_multi():
"""Multi-head Bahdanau attention."""
hparams = lstm_bahdanau_attention()
hparams.num_heads = 4
return hparams
@registry.register_hparams
def lstm_luong_attention_multi():
"""Multi-head Luong attention."""
hparams = lstm_luong_attention()
hparams.num_heads = 4
return hparams
@registry.register_hparams
def lstm_asr_v1():
"""Basic LSTM Params."""
hparams = lstm_bahdanau_attention()
hparams.num_hidden_layers = 2
hparams.hidden_size = 256
hparams.batch_size = 36
hparams.max_input_seq_length = 600000
hparams.max_target_seq_length = 350
hparams.max_length = hparams.max_input_seq_length
hparams.min_length_bucket = hparams.max_input_seq_length // 2
hparams.learning_rate = 0.05
return hparams
@registry.register_hparams
def lstm_area_attention_base():
"""Hparams for LSTM with area attention."""
hparams = lstm_luong_attention()
hparams.batch_size = 16384
hparams.num_hidden_layers = 2
hparams.hidden_size = 1024
hparams.num_heads = 4
hparams.dropout = 0.2
hparams.learning_rate = 0.1
hparams.max_area_width = 2
hparams.area_key_mode = "mean"
hparams.area_value_mode = "sum"
return hparams
@registry.register_hparams
def lstm_area_attention_enfr():
"""Hparams for LSTM with area attention."""
hparams = lstm_area_attention_base()
hparams.dropout = 0.1
return hparams
@registry.register_hparams
def lstm_area_attention_char():
"""Hparams for LSTM with area attention."""
hparams = lstm_area_attention_base()
hparams.batch_size = 20480
return hparams
@registry.register_hparams
def lstm_area_attention_char_enfr():
"""Hparams for LSTM with area attention."""
hparams = lstm_area_attention_char()
hparams.dropout = 0.1
return hparams
================================================
FILE: tensor2tensor/models/lstm_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LSTMSeq2Seq models tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import lstm
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
class LSTMTest(tf.test.TestCase):
def testLSTMSeq2Seq(self):
vocab_size = 9
x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
hparams = lstm.lstm_seq2seq()
p_hparams = problem_hparams.test_problem_hparams(vocab_size,
vocab_size,
hparams)
with self.test_session() as session:
features = {
"inputs": tf.constant(x, dtype=tf.int32),
"targets": tf.constant(y, dtype=tf.int32),
}
model = lstm.LSTMSeq2seq(hparams, tf_estimator.ModeKeys.TRAIN,
p_hparams)
logits, _ = model(features)
session.run(tf.global_variables_initializer())
res = session.run(logits)
self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
def testLSTMSeq2SeqAttention(self):
vocab_size = 9
x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
hparams = lstm.lstm_attention()
p_hparams = problem_hparams.test_problem_hparams(vocab_size,
vocab_size,
hparams)
x = tf.constant(x, dtype=tf.int32)
x = tf.placeholder_with_default(x, shape=[None, None, 1, 1])
with self.test_session() as session:
features = {
"inputs": x,
"targets": tf.constant(y, dtype=tf.int32),
}
model = lstm.LSTMSeq2seqAttention(
hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
logits, _ = model(features)
session.run(tf.global_variables_initializer())
res = session.run(logits)
self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
def testLSTMSeq2seqBidirectionalEncoder(self):
vocab_size = 9
x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
hparams = lstm.lstm_seq2seq()
p_hparams = problem_hparams.test_problem_hparams(vocab_size,
vocab_size,
hparams)
with self.test_session() as session:
features = {
"inputs": tf.constant(x, dtype=tf.int32),
"targets": tf.constant(y, dtype=tf.int32),
}
model = lstm.LSTMSeq2seqBidirectionalEncoder(
hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
logits, _ = model(features)
session.run(tf.global_variables_initializer())
res = session.run(logits)
self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
def testLSTMSeq2seqAttentionBidirectionalEncoder(self):
vocab_size = 9
x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
hparams = lstm.lstm_attention()
p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size)
x = tf.constant(x, dtype=tf.int32)
x = tf.placeholder_with_default(x, shape=[None, None, 1, 1])
with self.test_session() as session:
features = {
"inputs": x,
"targets": tf.constant(y, dtype=tf.int32),
}
model = lstm.LSTMSeq2seqAttentionBidirectionalEncoder(
hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
logits, _ = model(features)
session.run(tf.global_variables_initializer())
res = session.run(logits)
self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
if __name__ == "__main__":
tf.test.main()
================================================
FILE: tensor2tensor/models/mtf_image_transformer.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image Transformer model with model and data parallelism using MTF.
Integration of Mesh tensorflow with Image Transformer to do model parallelism.
Currently, this supports unconditional image generation. Specify a particular
architecture layout in the hparams that specifies how different dimensions are
split or replicated along the mesh dimensions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import mesh_tensorflow as mtf
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import mtf_model
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
@registry.register_model
class MtfImageTransformer(mtf_model.MtfModel):
"""Image Transformer in mesh_tensorflow."""
@property
def inputs_vocab_dim(self):
assert self.has_input
return mtf.Dimension("inputs_vocab", self._hparams.num_classes)
@property
def targets_vocab_dim(self):
vocab_size = self._problem_hparams.vocab_size["targets"]
if hasattr(self._hparams, "vocab_divisor"):
vocab_size += (-vocab_size) % self._hparams.vocab_divisor
return mtf.Dimension("vocab", vocab_size)
@property
def outputs_vocab_dim(self):
return mtf.Dimension("output_vocab", 256)
@property
def pos_dim(self):
return mtf.Dimension("pos", self._hparams.img_len)
@property
def rows_dim(self):
return mtf.Dimension("rows", self._hparams.img_len)
@property
def cols_dim(self):
return mtf.Dimension(
"cols", self._hparams.img_len*self._hparams.num_channels)
@property
def orig_cols_dim(self):
return mtf.Dimension("orig_cols", self._hparams.img_len)
@property
def channels_dim(self):
return mtf.Dimension("channels", self._hparams.num_channels)
@property
def model_dim(self):
return mtf.Dimension("d_model", self._hparams.hidden_size)
@property
def max_length_dim(self):
return mtf.Dimension(
"max_length",
self._hparams.img_len*self._hparams.img_len*self._hparams.num_channels)
@property
def length_dim(self):
return mtf.Dimension(
"length",
self._hparams.img_len*self._hparams.img_len*self._hparams.num_channels)
@property
def heads_dim(self):
return mtf.Dimension("heads", self._hparams.num_heads)
@property
def kv_dim(self):
return mtf.Dimension("d_kv", self._hparams.d_kv)
@property
def feedforward_dim(self):
return mtf.Dimension("d_ff", self._hparams.d_ff)
@property
def activation_type(self):
hparams = self._hparams
if hparams.activation_dtype == "float32":
activation_dtype = tf.float32
elif hparams.activation_dtype == "float16":
activation_dtype = tf.float16
elif hparams.activation_dtype == "bfloat16":
activation_dtype = tf.bfloat16
else:
raise ValueError(
"unknown hparams.activation_dtype %s" % hparams.activation_dtype)
return activation_dtype
def create_positional_emb_2d(self, targets):
"""Learned 2d positional embedding for images."""
mesh = targets.mesh
positional_emb_rows_var = mtf.get_variable(
mesh, "positional_emb_rows",
mtf.Shape([self.pos_dim, self.model_dim]),
initializer=tf.random_normal_initializer(),
activation_dtype=self.activation_type)
positional_emb_cols_var = mtf.get_variable(
mesh, "positional_emb_cols",
mtf.Shape([self.pos_dim, self.model_dim]),
initializer=tf.random_normal_initializer(),
activation_dtype=self.activation_type)
targets_position_x = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
targets_position_y = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
position_x = mtf.broadcast(
mtf.gather(positional_emb_rows_var, targets_position_x,
self.pos_dim),
mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
position_y = mtf.broadcast(
mtf.gather(positional_emb_cols_var, targets_position_y,
self.pos_dim),
mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
return position_x + position_y
def mtf_model_fn(self, features, mesh):
features = copy.copy(features)
tf.logging.info("features = %s" % features)
hparams = self._hparams
activation_dtype = self.activation_type
# We assume fixed vocab size for targets
targets = tf.to_int32(features["targets"])
# Image preprocessing, reshape into a 1D sequence and shift right.
length = hparams.img_len*hparams.img_len*hparams.num_channels
targets = tf.reshape(targets, [hparams.batch_size, length])
shifted_targets = common_layers.shift_right_2d(targets)
# Declare all the dimensions
batch_dim = mtf.Dimension("batch", hparams.batch_size)
def import_to_batch_by_length(x, name):
return mtf.import_tf_tensor(
mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)
targets = import_to_batch_by_length(targets, "targets")
shifted_targets = import_to_batch_by_length(
shifted_targets, "shifted_targets")
extra_losses = []
# Create targets content and position embeddings.
# Create embedding var for targets and positions and do a gather.
targets_embedding_var = mtf.get_variable(
mesh, "targets_embedding",
mtf.Shape([self.targets_vocab_dim, self.model_dim]),
initializer=tf.random_normal_initializer(),
activation_dtype=activation_dtype)
x = mtf.gather(targets_embedding_var,
shifted_targets, self.targets_vocab_dim)
# Add positional embeddings
x += mtf.reshape(self.create_positional_emb_2d(targets),
[self.length_dim, self.model_dim])
# If conditional and input is given, add the input embedding to the target.
# TODO(nikip): Verify conditional.
if self.has_input and not hparams.unconditional:
inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
inputs = import_to_batch_by_length(inputs, "inputs")
# Input embeddings
inputs_embedding_var = mtf.layers.embedding(
mesh, "input_embedding",
mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
activation_dtype=activation_dtype)
inputs_emb = mtf.gather(
inputs_embedding_var, inputs, self.inputs_vocab_dim)
x += inputs_emb
# Image Transformer Decoder
# [ self attention - ffn - residual + dropout] x n
if hparams.attention_type == "local1d_spatial":
decoder_output = local_attention1d_spatial_decoder(
x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
elif hparams.attention_type == "local2d_spatial":
decoder_output = local_attention2d_spatial_decoder(
x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
elif hparams.attention_type == "local1d":
decoder_output = local_attention1d_masked_decoder(
x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
else:
raise ValueError("Invalid attention type.")
# Calculate the logits and loss.
logits = mtf.layers.dense(
decoder_output, self.outputs_vocab_dim, name="logits")
# Need a reshape for logits
logits = mtf.reshape(
logits, mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
soft_targets = mtf.one_hot(
targets, self.outputs_vocab_dim, dtype=activation_dtype)
loss = mtf.layers.softmax_cross_entropy_with_logits(
logits, soft_targets, self.outputs_vocab_dim)
loss = mtf.reduce_mean(loss)
for l in extra_losses:
loss += l
# Reshape logits to original target shape.
logits = mtf.reshape(
logits,
mtf.Shape([batch_dim, self.rows_dim, self.orig_cols_dim,
self.channels_dim, self.outputs_vocab_dim]))
return logits, loss
def layer_prepostprocess_dropout(x, hparams):
batch_dim = x.shape.dims[0]
model_dim = x.shape.dims[-1]
mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
is_training = mode == tf_estimator.ModeKeys.TRAIN
return mtf.dropout(
x, is_training,
keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
noise_shape=mtf.Shape([batch_dim, model_dim]))
def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
feedforward_dim, hparams):
"""Image Transformer decoder with local1D spatial layers."""
batch_dim, length_dim, model_dim = x.shape.dims
blocks_w_dim = mtf.Dimension("blocksw", hparams.block_length)
num_w_blocks_dim = mtf.Dimension("num_wblocks",
length_dim.size // blocks_w_dim.size)
x = mtf.reshape(
x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
# [ self attention - ffn - residual + dropout] x n
mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
is_training = mode == tf_estimator.ModeKeys.TRAIN
for layer in range(hparams.num_decoder_layers):
layer_name = "decoder_layer_%d" % layer
with tf.variable_scope(layer_name):
# Self attention layer
x += layer_prepostprocess_dropout(
mtf.layers.local_self_attention_spatial_blocks(
mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
kv_dim,
heads_dim,
is_training,
memory_w_dim=blocks_w_dim,
mask_right=True,
name="self_att"), hparams)
# ffn layer
x += layer_prepostprocess_dropout(
mtf.layers.dense_relu_dense(
mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
feedforward_dim,
is_training,
hparams.dropout,
dropout_broadcast_dims=[length_dim]), hparams)
output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
return output
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
feedforward_dim, hparams):
"""Image Transformer decoder with local2D spatial layers."""
batch_dim, length_dim, model_dim = x.shape.dims
blocks_h_dim = mtf.Dimension("blocksh", hparams.block_height)
blocks_w_dim = mtf.Dimension("blocksw", hparams.block_width)
num_h_blocks_dim = mtf.Dimension("num_h_blocks",
hparams.img_len // hparams.block_height)
num_w_blocks_dim = mtf.Dimension(
"num_w_blocks",
hparams.img_len * hparams.num_channels // hparams.block_width)
x = mtf.transpose(
mtf.reshape(
x,
mtf.Shape([
batch_dim, num_h_blocks_dim, blocks_h_dim,
num_w_blocks_dim, blocks_w_dim, model_dim
])),
mtf.Shape([
batch_dim, num_h_blocks_dim, num_w_blocks_dim,
blocks_h_dim, blocks_w_dim, model_dim
]))
mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
is_training = mode == tf_estimator.ModeKeys.TRAIN
# Image Transformer Decoder
# [ self attention - ffn - residual + dropout] x n
for layer in range(hparams.num_decoder_layers):
layer_name = "decoder_layer_%d" % layer
with tf.variable_scope(layer_name):
# Self attention layer
x += layer_prepostprocess_dropout(
mtf.layers.local_2d_self_attention_spatial_blocks(
mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
kv_dim,
heads_dim,
is_training,
memory_h_dim=num_h_blocks_dim,
memory_w_dim=num_w_blocks_dim,
name="self_att"), hparams)
# ffn layer
x += layer_prepostprocess_dropout(
mtf.layers.dense_relu_dense(
mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
feedforward_dim,
hparams.dropout,
dropout_broadcast_dims=[length_dim]), hparams)
output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
return output
def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
feedforward_dim, hparams):
"""Image Transformer decoder with local1D masked layers."""
print(x)
_, length_dim, model_dim = x.shape.dims
mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
is_training = mode == tf_estimator.ModeKeys.TRAIN
for layer in range(hparams.num_decoder_layers):
layer_name = "decoder_layer_%d" % layer
with tf.variable_scope(layer_name):
# Self attention layer
length_per_split = mtf.tensor_dim_to_size_per_split(
hparams.layout, hparams.mesh_shape, length_dim)
x += layer_prepostprocess_dropout(
mtf.layers.masked_local_attention_1d(
mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
kv_dim,
heads_dim,
is_training,
window_size=hparams.block_length,
length_per_split=length_per_split,
name="self_att"), hparams)
# ffn layer
x += layer_prepostprocess_dropout(
mtf.layers.dense_relu_dense(
mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
feedforward_dim,
hparams.dropout,
dropout_broadcast_dims=[length_dim]), hparams)
output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
return output
@registry.register_hparams
def mtf_image_transformer_base():
"""Set of hyperparameters."""
hparams = common_hparams.basic_params1()
hparams.no_data_parallelism = True
hparams.use_fixed_batch_size = True
hparams.batch_size = 1
hparams.max_length = 3072
hparams.hidden_size = 256
hparams.label_smoothing = 0.0
# 8-way model-parallelism
hparams.add_hparam("mesh_shape", "batch:8")
hparams.add_hparam("layout", "batch:batch")
hparams.add_hparam("mtf_mode", True)
hparams.add_hparam("num_heads", 8)
hparams.add_hparam("filter_size", 1024)
hparams.add_hparam("num_encoder_layers", 0)
hparams.add_hparam("num_decoder_layers", 6)
hparams.add_hparam("attention_key_size", 256)
hparams.add_hparam("attention_value_size", 256)
# Share weights between input and target embeddings
hparams.shared_embedding = True
# mixture of experts hparams
hparams.add_hparam("ffn_layer", "dense_relu_dense")
hparams.add_hparam("moe_overhead_train", 1.0)
hparams.add_hparam("moe_overhead_eval", 2.0)
hparams.moe_num_experts = 16
hparams.moe_loss_coef = 1e-3
hparams.shared_embedding_and_softmax_weights = True
hparams.optimizer = "Adafactor"
hparams.learning_rate_schedule = "rsqrt_decay"
hparams.learning_rate_warmup_steps = 10000
hparams.add_hparam("d_kv", 64)
hparams.add_hparam("d_ff", 2048)
# Image related hparams
hparams.add_hparam("img_len", 32)
hparams.add_hparam("num_channels", 3)
hparams.add_hparam("unconditional", True)
# Local Attention related params
hparams.add_hparam("block_length", 128)
hparams.add_hparam("block_height", 16)
hparams.add_hparam("block_width", 16)
hparams.add_hparam("attention_type", "local1d")
return hparams
@registry.register_hparams
def mtf_image_transformer_tiny():
"""Catch bugs locally..."""
hparams = mtf_image_transformer_base()
hparams.hidden_size = 128
hparams.d_ff = 256
hparams.batch_size = 4
hparams.num_encoder_layers = 1
hparams.num_decoder_layers = 4
hparams.num_heads = 4
hparams.attention_key_size = 128
hparams.attention_value_size = 128
hparams.block_length = 32
# data parallelism and model-parallelism
hparams.mesh_shape = "batch:2"
hparams.layout = "batch:batch"
return hparams
@registry.register_hparams
def mtf_image_transformer_single():
"""Small single parameters."""
hparams = mtf_image_transformer_tiny()
hparams.mesh_shape = ""
hparams.layout = ""
hparams.hidden_size = 32
hparams.filter_size = 32
hparams.batch_size = 1
hparams.num_encoder_layers = 1
hparams.num_decoder_layers = 1
hparams.num_heads = 2
hparams.attention_key_size = 32
hparams.attention_value_size = 32
hparams.block_length = 16
return hparams
@registry.register_hparams
def mtf_image_transformer_base_single():
"""Small single parameters."""
hparams = mtf_image_transformer_base()
hparams.num_decoder_layers = 6
hparams.filter_size = 256
hparams.block_length = 128
hparams.mesh_shape = ""
hparams.layout = ""
return hparams
@registry.register_hparams
def mtf_image_transformer_tiny_spatial1d():
"""Small single parameters."""
hparams = mtf_image_transformer_tiny()
hparams.num_decoder_layers = 6
hparams.filter_size = 128
hparams.block_height = 8
hparams.block_width = 8
hparams.attention_type = "local1d_spatial"
hparams.mesh_shape = ""
hparams.layout = ""
return hparams
@registry.register_hparams
def mtf_image_transformer_tiny_spatial2d():
"""Small single parameters."""
hparams = mtf_image_transformer_tiny()
hparams.num_decoder_layers = 6
hparams.filter_size = 128
hparams.block_height = 8
hparams.block_width = 8
hparams.attention_type = "local2d_spatial"
hparams.mesh_shape = "b1:2,b2:2"
hparams.layout = "num_h_blocks:b1,num_wblocks:b2"
return hparams
@registry.register_hparams
def mtf_image_transformer_base_cifar():
"""Data parallel CIFAR parameters."""
hparams = mtf_image_transformer_base()
hparams.mesh_shape = "batch:8"
hparams.layout = "batch:batch"
hparams.learning_rate_decay_steps = 13600 # one epoch
hparams.batch_size = 32
hparams.num_heads = 4
hparams.num_decoder_layers = 12
hparams.block_length = 256
hparams.hidden_size = 512
hparams.d_ff = 2048
hparams.learning_rate = 0.5
hparams.layer_preprocess_sequence = "none"
hparams.layer_postprocess_sequence = "dan"
hparams.layer_prepostprocess_dropout = 0.3
hparams.unconditional = True
return hparams
@registry.register_hparams
def mtf_image_transformer_cifar_4x():
"""Data parallel CIFAR parameters."""
hparams = mtf_image_transformer_base_cifar()
hparams.mesh_shape = "batch:32"
hparams.layout = "batch:batch"
hparams.batch_size = 128
return hparams
@registry.register_hparams
def mtf_image_transformer_cifar_mp_4x():
"""Data parallel CIFAR parameters."""
hparams = mtf_image_transformer_base_cifar()
hparams.mesh_shape = "model:4;batch:8"
hparams.layout = "batch:batch;d_ff:model;heads:model"
hparams.batch_size = 32
hparams.num_heads = 8
hparams.d_ff = 8192
return hparams
@registry.register_hparams
def mtf_image_transformer_base_imagenet():
"""Data parallel CIFAR parameters."""
hparams = mtf_image_transformer_base_cifar()
hparams.mesh_shape = "batch:32"
hparams.layout = "batch:batch"
hparams.batch_size = 128
hparams.d_ff = 2048
hparams.hidden_size = 512
hparams.num_decoder_layers = 12
hparams.learning_rate = 0.5
hparams.learning_rate_warmup_steps = 31250
hparams.layer_preprocess_sequence = "none"
hparams.layer_postprocess_sequence = "dan"
hparams.layer_prepostprocess_dropout = 0.1
hparams.unconditional = True
return hparams
@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp():
"""Model parallel ImageNet parameters."""
hparams = mtf_image_transformer_base_imagenet()
hparams.mesh_shape = "model:4;batch:8"
hparams.layout = "batch:batch;d_ff:model;heads:model"
hparams.batch_size = 32
hparams.num_heads = 8
hparams.d_ff = 8192
hparams.learning_rate_warmup_steps = 31250
hparams.unconditional = True
return hparams
@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp128():
"""Model parallel ImageNet parameters."""
hparams = mtf_image_transformer_base_imagenet()
hparams.mesh_shape = "model:8;batch:4"
hparams.layout = "batch:batch;d_ff:model;heads:model"
hparams.batch_size = 8
hparams.img_len = 128
hparams.block_length = 128
hparams.num_heads = 8
hparams.num_decoder_layers = 4
hparams.d_ff = 4096
hparams.learning_rate_warmup_steps = 31250
hparams.unconditional = True
hparams.max_length = 256*256*3
return hparams
@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp_sp():
"""Model parallel ImageNet parameters."""
hparams = mtf_image_transformer_base_imagenet_mp128()
hparams.mesh_shape = "model:8;batch:4"
hparams.layout = "batch:batch;d_ff:model;num_wblocks:model"
hparams.batch_size = 8
hparams.img_len = 128
hparams.block_length = 128
hparams.attention_type = "local1d_spatial"
return hparams
@registry.register_hparams
def mtf_image_transformer_base_imagenet_mp64():
"""Model parallel ImageNet parameters."""
hparams = mtf_image_transformer_base_imagenet()
hparams.mesh_shape = "model:8;batch:4"
hparams.layout = "batch:batch;d_ff:model;heads:model"
hparams.batch_size = 8
hparams.img_len = 64
hparams.num_decoder_layers = 8
return hparams
@registry.register_hparams
def mtf_image_transformer_tiny_8gpu():
hparams = mtf_image_transformer_tiny()
hparams.mesh_shape = "all:8"
hparams.layout = "vocab:all;filter_size:all;heads:all"
return hparams
@registry.register_hparams
def mtf_image_transformer_length_sharded():
hparams = mtf_image_transformer_tiny()
hparams.mesh_shape = "all:2"
hparams.layout = "length:all"
return hparams
================================================
FILE: tensor2tensor/models/mtf_image_transformer_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Image Transformer on Mesh TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import mesh_tensorflow as mtf
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import mtf_image_transformer
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# Constants shared between all functions:
#   BATCH_SIZE - examples per batch (first dim of the generated targets).
#   IMG_LENGTH - image height and width (also hparams.img_len).
#   VOCAB_SIZE - target values are drawn from [0, VOCAB_SIZE).
BATCH_SIZE, IMG_LENGTH, VOCAB_SIZE = 8, 8, 256
def get_model(hparams=None,
              mode=tf_estimator.ModeKeys.TRAIN,
              model_cls=mtf_image_transformer.MtfImageTransformer):
  """Builds a model plus a random targets-only feature dict for testing.

  Args:
    hparams: optional HParams; defaults to mtf_image_transformer_single().
      The passed object is modified in place (sizes set to the test
      constants, problem_hparams attached).
    mode: estimator mode passed to the model constructor.
    model_cls: model class to instantiate.
  Returns:
    (model, features, hparams), where features["targets"] is an int32
    constant of shape (BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, 1, 1).
  """
  if hparams is None:
    hparams = mtf_image_transformer.mtf_image_transformer_single()
  hparams.max_length = IMG_LENGTH * IMG_LENGTH
  hparams.batch_size = BATCH_SIZE
  hparams.img_len = IMG_LENGTH
  hparams.num_channels = 1

  p_hparams = problem_hparams.test_problem_hparams(
      VOCAB_SIZE, VOCAB_SIZE, hparams)
  # Drop the "inputs" modality; the features built below only hold targets.
  del p_hparams.modality["inputs"]
  hparams.problem_hparams = p_hparams

  target_shape = (BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, 1, 1)
  random_targets = np.random.randint(VOCAB_SIZE, size=target_shape)
  features = {
      "targets": tf.constant(random_targets, dtype=tf.int32, name="targets"),
  }
  model = model_cls(hparams, mode, p_hparams)
  return model, features, hparams
def get_placement_mesh(hparams):
  """Creates a fresh mtf mesh and a PlacementMeshImpl for it.

  Args:
    hparams: HParams providing `mesh_shape` and `layout` strings.
  Returns:
    (mesh, mesh_impl); every processor is given the device string "".
  """
  mesh = mtf.Mesh(mtf.Graph(), "my_mesh")
  shape = mtf.convert_to_shape(hparams.mesh_shape)
  devices = ["" for _ in range(shape.size)]
  impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape, hparams.layout, devices)
  return mesh, impl
class MtfImageTransformerTest(tf.test.TestCase):
  """Checks MtfImageTransformer logits shapes under several mesh layouts."""

  def _check_logits_shape(self, mesh_shape, layout):
    """Builds, lowers and runs the model, then checks the logits shape.

    Args:
      mesh_shape: mesh shape string ("" for an empty mesh).
      layout: layout rules string ("" for no splitting).
    """
    hparams = mtf_image_transformer.mtf_image_transformer_single()
    # need to know layout ahead of time for local attention.
    hparams.mesh_shape = mesh_shape
    hparams.layout = layout
    model, features, hparams = get_model(hparams)
    mesh, mesh_impl = get_placement_mesh(hparams)

    logits, _ = model.mtf_model_fn(features, mesh)
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    tf_group = lowering.copy_masters_to_slices()
    tf_logits = lowering.export_to_tf_tensor(logits)

    # cached_session() is the documented replacement for the deprecated
    # test_session().
    with self.cached_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(tf_group)
      res = session.run(tf_logits)
    self.assertEqual(
        res.shape,
        (BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, hparams.num_channels, VOCAB_SIZE))

  def testMtfImageTransformer(self):
    self._check_logits_shape(mesh_shape="", layout="")

  def testMtfImageTransformerDataParallel(self):
    self._check_logits_shape(mesh_shape="all:2", layout="batch:all")

  def testMtfImageTransformerModelParallel(self):
    self._check_logits_shape(mesh_shape="all:2", layout="length:all")


if __name__ == "__main__":
  tf.test.main()
================================================
FILE: tensor2tensor/models/mtf_resnet.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet model with model and data parallelism using MTF.
Integration of Mesh tensorflow with ResNet to do model parallelism.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import mesh_tensorflow as mtf
from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import mtf_model
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# Hyperparameters shared by every batch-norm layer in this model.
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5


def batch_norm_relu(inputs, is_training, relu=True):
  """Applies mtf batch normalization, optionally followed by a ReLU.

  Args:
    inputs: a `mtf.Tensor`.
    is_training: passed through to `mtf.layers.batch_norm`.
    relu: if True, apply `mtf.relu` after normalizing. When False, the batch
      norm is created with `init_zero=True`.
  Returns:
    a `mtf.Tensor`.
  """
  normalized = mtf.layers.batch_norm(
      inputs,
      is_training,
      BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON,
      init_zero=not relu)
  return mtf.relu(normalized) if relu else normalized
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Applies a 1x1 conv to `filters` channels, a 3x3 conv to `4 * filters`
  channels and a 1x1 conv back to `filters` channels. The first two convs are
  followed by batch norm + ReLU. The (optionally projected) shortcut is then
  added and a final ReLU applied.

  Args:
    inputs: a `mtf.Tensor` of shape
      `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of output channels of the first and third
      convolutions. The second (3x3) convolution uses `4 * filters`.
    is_training: `bool` for whether the model is in training mode.
    strides: stride passed to the third convolution (callers in this file
      pass lists such as `[1, 1]`).
    projection_shortcut: `function` called as
      `projection_shortcut(inputs, filters_dim)` (with a "filtersp" dimension
      of size `filters`) to build the shortcut. If None, `inputs` is added
      unchanged.
    row_blocks_dim: a mtf.Dimension, the row-blocks dimension that is
      spatially partitioned along a mesh axis (used only by the 3x3 conv).
    col_blocks_dim: a mtf.Dimension, the column-blocks dimension that is
      spatially partitioned along a mesh axis.
  Returns:
    The output `mtf.Tensor` of the block; its last dimension is the
    shortcut's last dimension.
  """
  shortcut = inputs
  if projection_shortcut is not None:
    filters_dim = mtf.Dimension("filtersp", filters)
    shortcut = projection_shortcut(inputs, filters_dim)

  # First conv block (1x1): no row halo is needed, so h_blocks_dim is None.
  inputs = mtf.layers.conv2d_with_blocks(
      inputs,
      mtf.Dimension("filters1", filters),
      filter_size=[1, 1],
      strides=[1, 1],
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
      name="conv0")
  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block (3x3, widened to 4 * filters).
  inputs = mtf.layers.conv2d_with_blocks(
      inputs,
      mtf.Dimension("filters2", 4 * filters),
      filter_size=[3, 3],
      strides=[1, 1],
      padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
      name="conv1")
  inputs = batch_norm_relu(inputs, is_training)

  # Third conv block (1x1), projecting back to `filters` channels.
  inputs = mtf.layers.conv2d_with_blocks(
      inputs,
      mtf.Dimension("filters3", filters),
      filter_size=[1, 1],
      strides=strides,
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
      name="conv2")

  # TODO(nikip): Although the original resnet code has this batch norm, in our
  # setup this is causing no gradients to be passed. Investigate further.
  # inputs = batch_norm_relu(inputs, is_training, relu=True)

  # TODO(nikip): Maybe add residual with a projection?
  # The conv output's channel dimension is renamed to the shortcut's so that
  # the two tensors can be added.
  return mtf.relu(
      shortcut + mtf.rename_dimension(
          inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
def block_layer(inputs,
                filters,
                blocks,
                strides,
                is_training,
                name,
                row_blocks_dim=None,
                col_blocks_dim=None):
  """Creates one layer of blocks for the ResNet model.

  The first bottleneck block uses a 1x1 projection shortcut and `strides`.
  The remaining `blocks - 1` blocks use identity shortcuts and unit strides,
  each in its own "bottleneck_<i>" variable scope.

  Args:
    inputs: a `mtf.Tensor` of shape
      `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters passed to every bottleneck block.
    blocks: `int` number of blocks contained in the layer.
    strides: stride used by the first block's third convolution and by its
      projection shortcut.
    is_training: `bool` for whether the model is training.
    name: `str` name of the variable scope wrapping the layer.
    row_blocks_dim: a mtf.Dimension, the row-blocks dimension that is
      spatially partitioned along a mesh axis.
    col_blocks_dim: a mtf.Dimension, the column-blocks dimension that is
      spatially partitioned along a mesh axis.
  Returns:
    The output `mtf.Tensor` of the block layer.
  """
  with tf.variable_scope(name, default_name="block_layer"):
    # Only the first block per block_layer uses projection_shortcut and strides
    def projection_shortcut(inputs, output_dim):
      """Project identity branch."""
      inputs = mtf.layers.conv2d_with_blocks(
          inputs,
          output_dim,
          filter_size=[1, 1],
          strides=strides,
          padding="SAME",
          h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
          name="shortcut0")
      return batch_norm_relu(
          inputs, is_training, relu=False)

    inputs = bottleneck_block(
        inputs,
        filters,
        is_training,
        strides=strides,
        projection_shortcut=projection_shortcut,
        row_blocks_dim=row_blocks_dim,
        col_blocks_dim=col_blocks_dim)

    for i in range(1, blocks):
      with tf.variable_scope("bottleneck_%d" % i):
        inputs = bottleneck_block(
            inputs,
            filters,
            is_training,
            strides=[1, 1, 1, 1],
            projection_shortcut=None,
            row_blocks_dim=row_blocks_dim,
            col_blocks_dim=col_blocks_dim)

    return inputs
@registry.register_model
class MtfResNet(mtf_model.MtfModel):
  """ResNet in mesh_tensorflow."""

  def set_activation_type(self):
    """Returns the tf dtype named by `hparams.activation_dtype`.

    Raises:
      ValueError: if the name is not "float32", "float16" or "bfloat16".
    """
    hparams = self._hparams
    supported_dtypes = {
        "float32": tf.float32,
        "float16": tf.float16,
        "bfloat16": tf.bfloat16,
    }
    if hparams.activation_dtype not in supported_dtypes:
      raise ValueError(
          "unknown hparams.activation_dtype %s" % hparams.activation_dtype)
    return supported_dtypes[hparams.activation_dtype]

  def mtf_model_fn(self, features, mesh):
    """Builds the ResNet classifier on `mesh`.

    Args:
      features: dict of tf.Tensors. "inputs" is reshaped into spatial blocks
        (assumed to hold batch_size * rows_size * cols_size * num_channels
        values laid out as [batch, rows, cols, channels] -- TODO confirm).
        "targets" has dims 2 and 3 squeezed and is reshaped to [batch_size].
      mesh: the mtf.Mesh to build on.
    Returns:
      logits: mtf.Tensor of shape [batch, one_channel, classes].
      loss: scalar mtf.Tensor, mean softmax cross-entropy over the batch.
    Raises:
      ValueError: if rows_size / cols_size are not divisible by row_blocks /
        col_blocks, or activation_dtype is unsupported.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s", features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf_estimator.ModeKeys.TRAIN

    if hparams.rows_size % hparams.row_blocks:
      raise ValueError("hparams.row_blocks (%d) must divide hparams.rows_size "
                       "(%d)" % (hparams.row_blocks, hparams.rows_size))
    if hparams.cols_size % hparams.col_blocks:
      raise ValueError("hparams.col_blocks (%d) must divide hparams.cols_size "
                       "(%d)" % (hparams.col_blocks, hparams.cols_size))

    # Declare all the dimensions. rows_dim / cols_dim are the per-block
    # extents; previously the tf reshape and the mtf shape disagreed on these
    # sizes (and on the channel count), so the import could not succeed.
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_dim = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension(
        "rows_size", hparams.rows_size // hparams.row_blocks)
    cols_dim = mtf.Dimension(
        "cols_size", hparams.cols_size // hparams.col_blocks)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    # NOTE(review): the number of classes is hard-coded to 10.
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", hparams.num_channels)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # The tf reshape and the mtf import use the same shape, so they always
    # agree.
    image_shape = mtf.Shape(
        [batch_dim, row_blocks_dim, rows_dim,
         col_blocks_dim, cols_dim, channels_dim])
    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(features["inputs"], image_shape.to_integer_list),
        image_shape)
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])
    x = mtf.to_float(x)
    x = mtf.layers.conv2d_with_blocks(
        x,
        filter_dim,
        filter_size=[3, 3],
        strides=[1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim,
        name="initial_filter")
    x = batch_norm_relu(x, is_training)

    # Conv blocks: hparams.num_layers repetitions of three block layers, each
    # with unit strides.
    for layer in range(hparams.num_layers):
      layer_name = "block_layer_%d" % layer
      with tf.variable_scope(layer_name):
        # Residual block layer
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[0],
            blocks=hparams.layer_sizes[0],
            strides=[1, 1],
            is_training=is_training,
            name="block_layer1",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[1],
            blocks=hparams.layer_sizes[1],
            strides=[1, 1],
            is_training=is_training,
            name="block_layer2",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[2],
            blocks=hparams.layer_sizes[2],
            strides=[1, 1],
            is_training=is_training,
            name="block_layer3",
            row_blocks_dim=None,
            col_blocks_dim=None)

    # Calculate the logits and loss. The dense layer reduces every dimension
    # except batch.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
@registry.register_hparams
def mtf_resnet_base():
  """Base set of hyperparameters for MtfResNet."""
  hparams = common_hparams.basic_params1()
  hparams.no_data_parallelism = True
  hparams.use_fixed_batch_size = True
  hparams.batch_size = 32
  hparams.max_length = 3072
  hparams.hidden_size = 256
  hparams.label_smoothing = 0.0
  # 8-way data parallelism: the "batch" tensor dimension is split over the
  # 8-processor "batch" mesh dimension (no model parallelism here).
  hparams.add_hparam("mesh_shape", "batch:8")
  hparams.add_hparam("layout", "batch:batch")
  hparams.add_hparam("filter_size", 1024)

  hparams.add_hparam("num_layers", 6)
  # Share weights between input and target embeddings (these flags are not
  # read by MtfResNet.mtf_model_fn).
  hparams.shared_embedding = True

  hparams.shared_embedding_and_softmax_weights = True
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.add_hparam("d_kv", 32)

  # Image related hparams
  hparams.add_hparam("img_len", 32)
  hparams.add_hparam("num_channels", 3)
  hparams.add_hparam("row_blocks", 1)
  hparams.add_hparam("col_blocks", 1)
  hparams.add_hparam("rows_size", 32)
  hparams.add_hparam("cols_size", 32)

  # Model-specific parameters. MtfResNet.mtf_model_fn reads only the first
  # three entries of layer_sizes and filter_sizes.
  hparams.add_hparam("layer_sizes", [3, 4, 6, 3])
  hparams.add_hparam("filter_sizes", [64, 64, 128, 256, 512])
  hparams.add_hparam("is_cifar", False)

  # Variable init
  hparams.initializer = "normal_unit_scaling"
  hparams.initializer_gain = 2.

  # TODO(nikip): Change optimization scheme?
  hparams.learning_rate = 0.1
  return hparams
@registry.register_hparams
def mtf_resnet_tiny():
  """Small hparams for catching bugs locally."""
  hparams = mtf_resnet_base()
  hparams.num_layers = 2
  hparams.hidden_size = 64
  hparams.filter_size = 64
  hparams.batch_size = 16
  # Data parallelism only: the batch is split 2 ways, and col_blocks = 1
  # means the columns are not partitioned.
  hparams.col_blocks = 1
  hparams.mesh_shape = "batch:2"
  hparams.layout = "batch:batch"
  hparams.layer_sizes = [1, 2, 3]
  hparams.filter_sizes = [64, 64, 64]
  return hparams
@registry.register_hparams
def mtf_resnet_single():
  """mtf_resnet_tiny shrunk further and placed on an empty mesh."""
  hparams = mtf_resnet_tiny()
  # Empty mesh shape and layout: no tensor dimension is split.
  hparams.layout = ""
  hparams.mesh_shape = ""
  # Smaller sizes.
  hparams.batch_size = 1
  hparams.hidden_size = 32
  hparams.filter_size = 32
  hparams.num_layers = 1
  hparams.num_encoder_layers = 1
  hparams.block_length = 16
  return hparams
@registry.register_hparams
def mtf_resnet_base_single():
  """mtf_resnet_base on an empty mesh with 6 layers."""
  hparams = mtf_resnet_base()
  # Empty mesh shape and layout: no tensor dimension is split.
  hparams.layout = ""
  hparams.mesh_shape = ""
  hparams.num_layers = 6
  hparams.block_length = 128
  hparams.filter_size = 256
  return hparams
@registry.register_hparams
def mtf_resnet_base_cifar():
  """Data parallel CIFAR parameters."""
  hparams = mtf_resnet_base()
  hparams.mesh_shape = "batch:32"
  # Was misspelled `layoyt`, which silently created an unused attribute
  # instead of setting the layout.
  hparams.layout = "batch:batch"
  # NOTE(review): batch_size 8 is smaller than the 32-way "batch" mesh
  # dimension it is laid out over; confirm Mesh TensorFlow can split it.
  hparams.batch_size = 8
  hparams.num_layers = 12
  hparams.block_length = 256
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.learning_rate = 0.5
  hparams.learning_rate_warmup_steps = 4000
  hparams.layer_preprocess_sequence = "none"
  hparams.layer_postprocess_sequence = "dan"
  hparams.layer_prepostprocess_dropout = 0.3
  hparams.unconditional = True
  return hparams
================================================
FILE: tensor2tensor/models/mtf_transformer.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import mesh_tensorflow as mtf
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.models.research import moe
from tensor2tensor.utils import mtf_model
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
@registry.register_model
class MtfTransformer(mtf_model.MtfModel):
"""Transformer in mesh_tensorflow."""
def __init__(self,
             hparams,
             mode=tf_estimator.ModeKeys.TRAIN,
             problem_hparams=None,
             data_parallelism=None,
             decode_hparams=None,
             **kwargs):
  """Repeats the encoder/decoder layer lists, then runs the base init.

  `hparams.encoder_layers` is multiplied by
  `hparams.encoder_replicate_factor` and `hparams.decoder_layers` by
  `hparams.decoder_replicate_factor`, in place on the passed hparams.
  NOTE(review): because the caller's hparams is mutated, building a second
  model from the same object multiplies the lists again.
  """
  repeated_encoder = (
      hparams.encoder_layers * hparams.encoder_replicate_factor)
  repeated_decoder = (
      hparams.decoder_layers * hparams.decoder_replicate_factor)
  hparams.encoder_layers = repeated_encoder
  hparams.decoder_layers = repeated_decoder
  super(MtfTransformer, self).__init__(
      hparams,
      mode=mode,
      problem_hparams=problem_hparams,
      data_parallelism=data_parallelism,
      decode_hparams=decode_hparams,
      **kwargs)
@property
def batch_dims(self):
  """List of batch `mtf.Dimension`s.

  A single "batch" dimension when hparams.outer_batch_size == 0; otherwise
  ["outer_batch", "inner_batch"] whose sizes multiply to hparams.batch_size.

  Raises:
    ValueError: if outer_batch_size does not divide batch_size.
  """
  hparams = self._hparams
  outer_size = hparams.outer_batch_size
  if outer_size == 0:
    return [mtf.Dimension("batch", hparams.batch_size)]
  inner_size, remainder = divmod(hparams.batch_size, outer_size)
  if remainder != 0:
    raise ValueError(
        "hparams.outer_batch_size must divide hparams.batch_size")
  return [mtf.Dimension("outer_batch", outer_size),
          mtf.Dimension("inner_batch", inner_size)]
@property
def inputs_vocab_dim(self):
  """The "vocab" Dimension sized by the padded inputs vocabulary."""
  assert self.has_input
  vocab_size = self._inputs_vocab_size
  return mtf.Dimension("vocab", vocab_size)
@property
def targets_vocab_dim(self):
  """The "vocab" Dimension sized by the padded targets vocabulary."""
  vocab_size = self._targets_vocab_size
  return mtf.Dimension("vocab", vocab_size)
@property
def model_dim(self):
  """The "d_model" Dimension, of size hparams.d_model."""
  size = self._hparams.d_model
  return mtf.Dimension("d_model", size)
@property
def max_length_dim(self):
  """The "max_length" Dimension, of size hparams.max_length."""
  size = self._hparams.max_length
  return mtf.Dimension("max_length", size)
@property
def length_dim(self):
  """The "length" Dimension, of size hparams.max_length."""
  size = self._hparams.max_length
  return mtf.Dimension("length", size)
@property
def memory_length_dim(self):
  """The "memory_length" Dimension, of size hparams.max_length."""
  size = self._hparams.max_length
  return mtf.Dimension("memory_length", size)
@property
def heads_dim(self):
  """The "heads" Dimension, of size hparams.num_heads."""
  size = self._hparams.num_heads
  return mtf.Dimension("heads", size)
@property
def kv_dim(self):
  """The "d_kv" Dimension, of size hparams.d_kv."""
  size = self._hparams.d_kv
  return mtf.Dimension("d_kv", size)
@property
def feedforward_dim(self):
  """The "d_ff" Dimension, of size hparams.d_ff."""
  size = self._hparams.d_ff
  return mtf.Dimension("d_ff", size)
@property
def master_dtype(self):
  """tf dtype named by hparams.master_dtype."""
  dtype_name = self._hparams.master_dtype
  return tf.as_dtype(dtype_name)
@property
def slice_dtype(self):
  """tf dtype named by hparams.slice_dtype."""
  dtype_name = self._hparams.slice_dtype
  return tf.as_dtype(dtype_name)
@property
def activation_dtype(self):
  """tf dtype named by hparams.activation_dtype."""
  dtype_name = self._hparams.activation_dtype
  return tf.as_dtype(dtype_name)
def _import_to_batch_by_length(self, x, name, mesh, hparams):
  """Imports tf tensor `x` onto `mesh` as a fully replicated mtf Tensor.

  `x` is reshaped to batch_dims + [length_dim] before import.
  `hparams` is accepted for interface compatibility and ignored.
  """
  del hparams  # Unused.
  target_shape = mtf.Shape(self.batch_dims + [self.length_dim])
  reshaped = tf.reshape(x, target_shape.to_integer_list)
  return mtf.import_fully_replicated(mesh, reshaped, target_shape, name=name)
def _embedding_and_softmax_vars(self, mesh):
  """Creates the embedding, softmax and positional-embedding variables.

  Args:
    mesh: the mtf.Mesh on which to create variables.
  Returns:
    (inputs_embedding_var, targets_embedding_var, softmax_var,
     positional_embedding_var). inputs_embedding_var is None when
    hparams.transformer_type == "decoder"; targets_embedding_var is None when
    it is "encoder". With shared_embedding_and_softmax_weights, softmax_var is
    the targets (else inputs) embedding scaled by d_model ** -0.5.
  """
  hparams = self._hparams
  if hparams.transformer_type == "encoder":
    targets_embedding_var = None
  else:
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype,
        activation_dtype=self.activation_dtype)

  # Compare with None explicitly rather than relying on the truthiness of an
  # mtf.Tensor object.
  if hparams.transformer_type == "decoder":
    inputs_embedding_var = None
  elif hparams.shared_embedding and targets_embedding_var is not None:
    inputs_embedding_var = targets_embedding_var
  else:
    inputs_embedding_var = mtf.get_variable(
        mesh, "inputs_embedding",
        mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype,
        activation_dtype=self.activation_dtype)

  if hparams.shared_embedding_and_softmax_weights:
    tied_embedding_var = (
        targets_embedding_var if targets_embedding_var is not None
        else inputs_embedding_var)
    softmax_var = tied_embedding_var * (self.model_dim.size ** -0.5)
  else:
    softmax_var = mtf.get_variable(
        mesh,
        "softmax",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(
            stddev=self.model_dim.size**-0.5),
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype,
        activation_dtype=self.activation_dtype)

  # NOTE(review): unlike the variables above, this one is created without
  # master_dtype/slice_dtype, so mtf.get_variable's defaults apply. Left as-is
  # because changing it would change the stored variable dtype.
  positional_embedding_var = mtf.get_variable(
      mesh, "positional_embedding",
      mtf.Shape([self.max_length_dim, self.model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=self.activation_dtype)
  return (inputs_embedding_var, targets_embedding_var,
          softmax_var, positional_embedding_var)
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
  """Returns a noised version of `targets` according to `noising_spec`.

  Args:
    targets: an mtf.Tensor of target ids.
    noising_spec: dict whose "type" is "mask", "random_zipfian" or
      "transformer". "mask" and "random_zipfian" read "prob"; "transformer"
      reads "overrides".
    losses: optional list. For the "transformer" type, the noiser's loss is
      appended when this is not None.
  Returns:
    an mtf.Tensor of noised targets.
  Raises:
    NotImplementedError: for the "transformer" type outside TRAIN mode.
    ValueError: for an unknown noising_spec["type"].
  """
  if noising_spec["type"] == "mask":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
    return targets * mtf.cast(
        mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                    noising_spec["prob"]), targets.dtype)
  elif noising_spec["type"] == "random_zipfian":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens.
    # Rather than drawing the replacement tokens uniformly, we sample from
    # a distribution favoring lower token-ids, assuming that the ids have
    # been assigned in frequency order. The probability of choosing an
    # id is proportional to 1/(id+10)
    logits = mtf.log(1.0 / (mtf.range(
        targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
    logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
    r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    use_noise = mtf.less(
        mtf.random_uniform(targets.mesh, targets.shape), noising_spec["prob"])
    return mtf.where(use_noise, r, targets)
  elif noising_spec["type"] == "transformer":
    # Train a small transformer to fill in masked out values, then
    # sample from it.
    hparams = self._hparams
    if hparams.mode != tf_estimator.ModeKeys.TRAIN:
      raise NotImplementedError("Not implemented")
    noiser_hparams = copy.copy(self._hparams)
    noiser_hparams.del_hparam("mode")
    noiser_hparams.override_from_dict(noising_spec["overrides"])
    with tf.variable_scope("noiser"):
      noiser = MtfTransformer(
          noiser_hparams,
          mode=hparams.mode,
          problem_hparams=self._problem_hparams)
      logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
          self._original_features, targets.mesh)
      samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    # `losses` defaults to None (e.g. the eval-spec call in _noisy_targets);
    # guard the append as _feedforward_layer does instead of crashing.
    if losses is not None:
      losses.append(loss)
    return samples
  else:
    raise ValueError("unknown noising spec %s" % noising_spec)
def _noisy_targets(self, targets, losses=None):
  """Generate noisy targets for denoising models.

  Outside TRAIN mode, hparams.noising_spec_eval is used. In TRAIN mode,
  hparams.noising_spec_train is used. When
  hparams.noising_use_eval_during_train > 0, each example is instead noised
  with the eval spec with that probability.

  Args:
    targets: a Tensor
    losses: an optional list onto which to append training losses
  Returns:
    a Tensor intended to match `targets` in dtype and shape
  """
  hparams = self._hparams
  if hparams.mode != tf_estimator.ModeKeys.TRAIN:
    return self._noisy_targets_from_spec(targets, hparams.noising_spec_eval)

  noised = self._noisy_targets_from_spec(
      targets, hparams.noising_spec_train, losses=losses)
  eval_fraction = hparams.noising_use_eval_during_train
  if eval_fraction > 0:
    eval_noised = self._noisy_targets_from_spec(
        targets, hparams.noising_spec_eval)
    # One draw per example (length_dim removed from the shape), broadcast
    # over the sequence.
    per_example_draw = mtf.random_uniform(
        targets.mesh, targets.shape - self.length_dim)
    noised = mtf.where(
        mtf.less(per_example_draw, eval_fraction), eval_noised, noised)
  return noised
def _mtf_model_fn(self, features, mesh):
  """Builds the transformer on `mesh` and returns (logits, loss).

  Args:
    features: dict of tf.Tensors. "targets" is always read; "inputs" is read
      unless hparams.transformer_type == "decoder". The optional packing
      features "targets_segmentation", "targets_position",
      "inputs_segmentation" and "inputs_position" are padded and, when
      present, used to build attention masks and positions.
    mesh: the mtf.Mesh to build on.
  Returns:
    logits: float32 mtf.Tensor with shape
      batch_dims + [length_dim, targets_vocab_dim].
    loss: mtf scalar equal to reduce_mean(per-token cross-entropy * nonzero
      target weights) plus every entry appended to `extra_losses` (filled by
      `_noisy_targets` and `_layer_stack`).
  Raises:
    ValueError: if hparams.decoder_type is not "autoregressive" or
      "denoising".
  """
  # Kept so the "transformer" noiser in _noisy_targets_from_spec can re-read
  # the raw features.
  self._original_features = features
  # Shallow copy so the padded packing features assigned below do not modify
  # the caller's dict.
  features = copy.copy(features)
  hparams = self._hparams
  extra_losses = []
  targets = tf.to_int32(features["targets"])
  # NOTE(review): this uses getattr with a TRAIN default, while the logits
  # code further down reads hparams.mode directly.
  mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
  is_training = mode == tf_estimator.ModeKeys.TRAIN
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    # Drop dims 2 and 3 (presumably size-1, e.g. [batch, length, 1, 1]).
    targets = tf.squeeze(targets, [2, 3])
  # pad targets to max_length
  def pad_to_max_length(x):
    # Right-pads dim 1 with zeros up to hparams.max_length and fixes the
    # static shape to [batch_size, max_length].
    # NOTE(review): tf.pad fails if dim 1 already exceeds max_length.
    extra_length = hparams.max_length - tf.shape(x)[1]
    x = tf.pad(x, [[0, 0], [0, extra_length]])
    x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
    return x
  targets = pad_to_max_length(targets)
  targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
  for key in ["targets_segmentation", "targets_position",
              "inputs_segmentation", "inputs_position"]:
    if key in features:
      features[key] = pad_to_max_length(features[key])
  # Decoder inputs: right-shifted targets (autoregressive) or noised targets
  # (denoising).
  if hparams.decoder_type == "autoregressive":
    shifted_targets = mtf.shift(
        targets, offset=1, dim=self.length_dim, wrap=False)
  elif hparams.decoder_type == "denoising":
    shifted_targets = self._noisy_targets(targets, extra_losses)
  else:
    raise ValueError(
        "unknown hparams.decoder_type = %s" % hparams.decoder_type)
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation",
        mesh, hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position",
        mesh, hparams)
    decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
        targets_segmentation, dtype=self.activation_dtype)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
  else:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
    else:
      decoder_self_attention_mask = None

  def layer_prepostprocess_dropout(x):
    # Dropout whose mask is shared across the length dimension (noise_shape
    # contains only the batch dims and model_dim).
    return mtf.dropout(
        x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)

  if hparams.transformer_type == "decoder":
    encoder_output = None
    encoder_decoder_attention_mask = None
  else:
    # ENCODER
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = pad_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation",
          mesh, hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position",
          mesh, hparams)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              inputs_segmentation, dtype=self.activation_dtype))
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))

    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.gather(positional_embedding_var, inputs_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_self_attention_mask,
                            losses=extra_losses)

  if hparams.transformer_type == "encdec":
    if "inputs_segmentation" in features:
      # NOTE(review): this needs targets_segmentation, which is only defined
      # when "targets_segmentation" is also in features.
      encoder_decoder_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              targets_segmentation, inputs_segmentation,
              dtype=self.activation_dtype))
    else:
      encoder_decoder_attention_mask = encoder_self_attention_mask
    # The decoder attends over the encoder output along "memory_length".
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)

  if hparams.transformer_type != "encoder":
    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)
  if (hparams.reshape_logits_hack and
      hparams.mode == tf_estimator.ModeKeys.TRAIN):
    # For some reason, the logits computation is extremely slow on TPU
    # in some cases where the batch size per core is 1.  Reshape the logits
    # and the targets to double the batch size and halve the length.
    # TODO(noam): file a bug.
    old_dims = self.batch_dims + [self.length_dim]
    new_dims = self.batch_dims[:-1] + [
        mtf.Dimension(self.batch_dims[-1].name,
                      self.batch_dims[-1].size * 2),
        mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
    x = mtf.reshape(x, new_dims + [self.model_dim])
    targets = mtf.reshape(targets, new_dims)

  logits = mtf.matmul(x, softmax_var)
  if hparams.mode == tf_estimator.ModeKeys.TRAIN:
    logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
  # Label smoothing: off_value on every class, 1 - label_smoothing extra on
  # the true class.
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
      dtype=self.activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  # Positions whose target id is 0 get zero weight; the mean is still taken
  # over all positions.
  weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for l in extra_losses:
    loss += l
  if (hparams.reshape_logits_hack and
      hparams.mode == tf_estimator.ModeKeys.TRAIN):
    # Undo the reshape hack above.
    logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
  logits = mtf.to_float(logits)
  return logits, loss
def mtf_model_fn(self, features, mesh):
  """Runs _mtf_model_fn under the "transformer" scope.

  When there is more than one batch dimension, the logits' batch dims are
  merged into a single dimension named after the first one.

  Returns:
    (logits, loss)
  """
  with tf.variable_scope("transformer"):
    logits, loss = self._mtf_model_fn(features, mesh)
  batch_dims = self.batch_dims
  if len(batch_dims) <= 1:
    return logits, loss
  merged_dim = mtf.Dimension(batch_dims[0].name, mtf.Shape(batch_dims).size)
  trailing_dims = logits.shape.dims[-2:]
  logits = mtf.reshape(logits, [merged_dim] + trailing_dims)
  return logits, loss
@property
def _targets_vocab_size(self):
  """Targets vocab size rounded up to a multiple of hparams.vocab_divisor."""
  raw_size = self._problem_hparams.vocab_size["targets"]
  padding = (-raw_size) % self._hparams.vocab_divisor
  return raw_size + padding
@property
def _inputs_vocab_size(self):
  """Inputs vocab size rounded up to a multiple of hparams.vocab_divisor."""
  raw_size = self._problem_hparams.vocab_size["inputs"]
  padding = (-raw_size) % self._hparams.vocab_divisor
  return raw_size + padding
def _feedforward_layer(self, x, layer_type, losses=None):
  """Feed-forward layer.

  Args:
    x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    layer_type: one of "drd", "none", "moe", "hmoe"
    losses: optional list; for "moe"/"hmoe" the second value returned by the
      MoE layer is appended when this is not None.
  Returns:
    a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
  Raises:
    ValueError: if layer_type is not recognized.
  """
  hparams = self._hparams
  mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
  is_training = mode == tf_estimator.ModeKeys.TRAIN
  if layer_type == "drd":
    return mtf.layers.dense_relu_dense(
        x, self.feedforward_dim, is_training, dropout=hparams.relu_dropout,
        dropout_broadcast_dims=[self.length_dim],
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype)
  elif layer_type == "none":
    return x
  elif layer_type in ("moe", "hmoe"):
    # Both MoE variants share a call signature. Use the same `is_training`
    # as the "drd" branch (previously these read hparams.mode directly,
    # which raises if hparams has no "mode").
    moe_layer_fn = (moe.transformer_moe_layer_v1 if layer_type == "moe"
                    else moe.transformer_moe_layer_v2)
    output, loss = moe_layer_fn(
        x,
        self.model_dim,
        hparams,
        is_training,
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype)
    if losses is not None:
      losses.append(loss)
    return output
  else:
    raise ValueError("layer_type not recognized %s" % layer_type)
def _layer_stack(self,
                 x,
                 layers,
                 encoder_output=None,
                 self_attention_mask=None,
                 encdec_attention_mask=None,
                 losses=None,
                 step_num=None,
                 encdec_tensors=None,
                 states=None):
  """Encoder or decoder stack.

  Each layer is applied as a residual block on a normalized input,
  x += dropout(layer(normalize(x))), followed by a final normalize + dropout.
  In incremental mode the dropout is skipped.

  Args:
    x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      (in incremental mode x presumably has no length dimension: the
      feed-forward branch inserts a size-1 "length" dim temporarily)
    layers: a list of strings; "att", "enc_att", "local_att",
      "compressed_att", or a feed-forward type handled by
      self._feedforward_layer
    encoder_output: an optional mtf.Tensor with shape
      [<batch_dims>, encoder_length_dim, model_dim]
    self_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, memory_length_dim] containing values 0 or -inf.
    encdec_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
    losses: a list to be appended-to
    step_num: an optional mtf integer Scalar (used in incremental mode)
    encdec_tensors: an optional list of num_layers tuples, each of the form
      (q_var, o_var, k, v), (used in incremental mode), indexed by layer
      number
    states: an optional list of Tensors (used in incremental mode); consumed
      two at a time (prev_k, prev_v) by each "att"/"local_att" layer

  Returns:
    a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim].
    In incremental mode, a tuple (x, new_states), where new_states holds the
    updated (k, v) tensors of each "att"/"local_att" layer, in layer order.

  Raises:
    ValueError: if hparams make no sense (e.g. incremental "compressed_att")
  """
  hparams = self._hparams
  is_incremental = (step_num is not None)
  mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
  is_training = mode == tf_estimator.ModeKeys.TRAIN
  def layer_prepostprocess_dropout(x):
    # Dropout broadcast over the length dimension; identity when decoding
    # incrementally.
    if is_incremental:
      return x
    return mtf.dropout(
        x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
  num_layers = len(layers)
  # One norm-scale vector per layer plus one for the final normalization.
  num_layer_norms = num_layers + 1
  layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
  layer_norm_combined_var = mtf.get_variable(
      x.mesh,
      "layer_norm_scale",
      mtf.Shape([layer_norms_dim, self.model_dim]),
      initializer=tf.ones_initializer(),
      activation_dtype=x.dtype)
  layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
  def normalize(x):
    # Scale by the reciprocal root-mean-square over model_dim (no mean
    # subtraction, no bias) times a learned scale. Each call consumes the
    # next unused scale vector, so call order must match layer order.
    scale = layer_norm_vars.pop(0)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
    return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale
  if is_incremental:
    # Copy so that popping below does not mutate the caller's list.
    states = list(states)
    new_states = []
  tf.logging.info("states = %s" % (states,))
  for lnum, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, lnum)):
      if layer_type == "att":
        # Self attention layer
        if is_incremental:
          y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
              normalize(x),
              prev_k=states.pop(0),
              prev_v=states.pop(0),
              step_num=step_num,
              master_dtype=self.master_dtype,
              slice_dtype=self.slice_dtype,
              name="att")
          new_states.append(new_k)
          new_states.append(new_v)
          x += y
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_attention(
                  normalize(x), None,
                  self_attention_mask, self.kv_dim, self.heads_dim,
                  is_training,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="att"))
      elif layer_type == "enc_att":
        # Encoder-Decoder attention layer
        if is_incremental:
          # Encoder-Decoder attention layer; k and v were precomputed from
          # the encoder output and are looked up by layer number.
          q_var, o_var, k, v = encdec_tensors[lnum]
          x += mtf.layers.multihead_encdec_attention_incremental(
              normalize(x),
              q_var, o_var, k, v,
              encdec_attention_mask,
              name="enc_att")
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_attention(
                  normalize(x), encoder_output,
                  encdec_attention_mask, self.kv_dim, self.heads_dim,
                  is_training,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="enc_att"))
      elif layer_type == "local_att":
        if is_incremental:
          y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
              normalize(x),
              prev_k=states.pop(0),
              prev_v=states.pop(0),
              step_num=step_num,
              master_dtype=self.master_dtype,
              slice_dtype=self.slice_dtype,
              name="local_att")
          new_states.append(new_k)
          new_states.append(new_v)
          x += y
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.masked_local_attention_1d(
                  normalize(x),
                  self.kv_dim, self.heads_dim, is_training,
                  window_size=hparams.local_attention_window_size,
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  length_per_split=mtf.tensor_dim_to_size_per_split(
                      hparams.layout, hparams.mesh_shape,
                      self.max_length_dim),
                  name="local_att"))
      elif layer_type == "compressed_att":
        if is_incremental:
          raise ValueError("compressed_att incremental not implemented")
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_self_attention_memory_compressed(
                  normalize(x),
                  mask_right=True,
                  compression_factor=hparams.compression_factor,
                  kv_channels=self.kv_dim,
                  heads=self.heads_dim,
                  is_training=is_training,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="compressed_att"))
      else:
        if is_incremental:
          # insert length dimension (size 1) so the feed-forward layer sees
          # the same rank as in the non-incremental case.
          x_shape = x.shape
          shape_with_length = mtf.Shape(
              x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
              + x_shape.dims[-1:])
          x = mtf.reshape(x, shape_with_length)
        # ffn layer
        x += layer_prepostprocess_dropout(
            self._feedforward_layer(normalize(x), layer_type, losses=losses))
        if is_incremental:
          # remove length dimension
          x = mtf.reshape(x, x_shape)
  x = layer_prepostprocess_dropout(normalize(x))
  # Every norm-scale vector must have been consumed exactly once.
  assert not layer_norm_vars
  if is_incremental:
    return x, new_states
  else:
    return x
def sample(self, features, mesh):
  """Decode from `features`, building all ops under the "transformer" scope."""
  transformer_scope = tf.variable_scope("transformer")
  with transformer_scope:
    decoded = self._sample(features, mesh)
  return decoded
def _sample(self, features, mesh):
  """Decode token ids by greedy/temperature sampling or beam search.

  Args:
    features: a dictionary of tf.Tensors. For "encdec" models,
      features["inputs"] is encoded. For "decoder" models, features["inputs"]
      (falling back to features["targets"]) is used as forced partial
      targets when present.
    mesh: a mtf.Mesh

  Returns:
    an int32 mtf.Tensor of decoded ids; with beam search, the beam at index 0
    along the beam dimension.

  Raises:
    ValueError: if hparams.transformer_type is neither "encdec" nor "decoder".
  """
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    # Squeeze away extra (size-1) dims at axis 2 until the tensor is rank 2.
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    # Pad up to the static [batch_size, max_length] shape used on the mesh.
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    # Precompute encoder-decoder attention keys/values once per "enc_att"
    # layer so incremental decoding can reuse them at every step. The list is
    # indexed by decoder layer number (None for other layer types).
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.master_dtype, self.slice_dtype,
              self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  self.batch_dims + [self.heads_dim,
                                     self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  self.batch_dims + [self.heads_dim,
                                     self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                            [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    # The message names the hparam actually being checked.
    raise ValueError(
        "hparams.transformer_type = %s not yet supported"
        % hparams.transformer_type)

  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [self.heads_dim,
                                local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [beam_dim, self.heads_dim,
                                local_attention_window, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  # Zero-initialized (k, v) caches, two per self-attention layer, in the same
  # order _layer_stack pops them.
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if hparams.transformer_type == "encdec":
      # Non-padding input tokens per example; the decode length is derived
      # from the longest input in the batch.
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier
          + hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
    return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
@registry.register_hparams
def mtf_transformer_base():
  """Base hyperparameters for the mesh-tensorflow Transformer.

  Starts from common_hparams.basic_params1() and adds the mtf-specific
  settings: mesh shape and layout, encoder/decoder layer lists, dtypes and
  beam-search decode-length parameters.

  Returns:
    a hparams object
  """
  hparams = common_hparams.basic_params1()
  hparams.no_data_parallelism = True
  hparams.use_fixed_batch_size = True
  hparams.add_hparam("mtf_mode", True)
  hparams.batch_size = 64
  hparams.max_length = 256
  hparams.add_hparam("d_model", 512)
  hparams.add_hparam("d_kv", 128)
  hparams.add_hparam("local_attention_window_size", 128)
  hparams.label_smoothing = 0.1
  # 8-way model-parallelism
  hparams.add_hparam("mesh_shape", "model:8")
  hparams.add_hparam("layout", "batch:batch;vocab:model;d_ff:model;heads:model")
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("d_ff", 2048)
  hparams.add_hparam("encoder_replicate_factor", 1)
  hparams.add_hparam("decoder_replicate_factor", 1)
  # Layer lists: each string names one layer type ("att" = self-attention,
  # "enc_att" = encoder-decoder attention, "drd" = dense-relu-dense, ...).
  hparams.add_hparam("encoder_layers", ["att", "drd"] * 6)
  hparams.add_hparam("decoder_layers", ["att", "enc_att", "drd"] * 6)
  hparams.add_hparam("attention_dropout", 0.1)
  hparams.add_hparam("relu_dropout", 0.1)
  hparams.layer_prepostprocess_dropout = 0.1

  # Describes what model architecture:
  #   "encdec": encoder + autoregressive decoder
  #   "decoder": single-stack autoregressive sequence model.
  #   "encoder": single-stack non-autoregressive model
  #      with equal-length inputs and outputs.
  hparams.add_hparam("transformer_type", "encdec")

  # What does the decoder do:
  #   "autoregressive": Decoder left to right
  #   "denoising": Fills in masked-out values simultaneously
  hparams.add_hparam("decoder_type", "autoregressive")

  # Parameters describing the noising algorithm for denoising decoders.
  # Each spec is a dict with a "type" and a "prob".
  hparams.add_hparam("noising_spec_train", {"type": "mask", "prob": 0.15})
  hparams.add_hparam("noising_spec_eval", {"type": "mask", "prob": 0.15})
  # during training, we use the eval noiser with this probability
  hparams.add_hparam("noising_use_eval_during_train", 0.1)

  # round up vocab sizes to be a multiple of this value
  hparams.vocab_divisor = 128

  # Feed-forward layer type; options are "drd" (dense_relu_dense), "moe"
  # and "hmoe".
  hparams.add_hparam("feedforward_layer", "drd")

  # If True, then reuse targets_embedding_var * rsqrt(d_model) as softmax_var
  # If hparams.transformer_type == "encoder", then there is no targets embedding
  # so we reuse the inputs embedding instead.
  hparams.shared_embedding_and_softmax_weights = True
  # Reuse targets_embedding_var as inputs_embedding_var
  # relevant only if hparams.transformer_type == "encdec"
  hparams.shared_embedding = True
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "linear_warmup*rsqrt_decay*linear_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.add_hparam("master_dtype", "bfloat16")
  hparams.add_hparam("slice_dtype", "float32")
  hparams.activation_dtype = "bfloat16"

  # These parameters make Transformer model compatible with MtfTransformer
  # Do not override these, as mtf_transformer does not support other options.
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.bottom = {
      "inputs": modalities.identity_bottom,
      "targets": modalities.identity_bottom,
  }
  hparams.top = {
      "targets": modalities.identity_top,
  }

  # Parameters for computing the maximum decode length in beam search.
  # Maximum decode length is:
  #    min(max_length,
  #        decode_length_multiplier * input_length + decode_length_constant)
  hparams.add_hparam("decode_length_multiplier", 1.5)
  hparams.add_hparam("decode_length_constant", 10.0)

  # If nonzero, we split the batch across two tensor-dimensions named
  # "outer_batch" and "inner_batch", allowing for splitting across two mesh
  # dimensions.  This is necessary for hierarchical mixture of experts.
  # The two tensor dimensions have sizes hparams.outer_batch_size and
  # hparams.batch_size // hparams.outer_batch_size.
  hparams.add_hparam("outer_batch_size", 0)

  # TODO(noam): file a bug
  hparams.add_hparam("reshape_logits_hack", False)
  # Presumably the memory-compression factor used by "compressed_att"
  # layers (_layer_stack passes it as compression_factor) - confirm.
  hparams.add_hparam("compression_factor", 4)
  return hparams
@registry.register_hparams
def mtf_transformer_base_lm():
  """Decoder-only (language model) variant of mtf_transformer_base."""
  hparams = mtf_transformer_base()
  hparams.transformer_type = "decoder"
  # A single-stack LM reuses the encoder layer pattern (no "enc_att").
  hparams.decoder_layers = hparams.encoder_layers
  hparams.sampling_method = "random"
  hparams.label_smoothing = 0.0
  return hparams
@registry.register_hparams
def mtf_transformer_tiny():
  """Tiny configuration, meant to catch bugs locally."""
  hparams = mtf_transformer_base()
  # Two layers per stack, small widths.
  hparams.encoder_layers = ["att", "drd"] * 2
  hparams.decoder_layers = ["att", "enc_att", "drd"] * 2
  hparams.d_model = 128
  hparams.d_ff = 512
  hparams.num_heads = 8
  hparams.batch_size = 8
  hparams.activation_dtype = "float32"
  # 2-way data parallelism combined with 4-way model parallelism.
  hparams.mesh_shape = "batch:2;model:4"
  return hparams
@registry.register_hparams
def mtf_transformer_tiny_lm():
  """Decoder-only (language model) variant of mtf_transformer_tiny."""
  hparams = mtf_transformer_tiny()
  hparams.transformer_type = "decoder"
  # A single-stack LM reuses the encoder layer pattern (no "enc_att").
  hparams.decoder_layers = hparams.encoder_layers
  hparams.sampling_method = "random"
  hparams.label_smoothing = 0.0
  return hparams
@registry.register_hparams
def mtf_transformer_tiny_denoising():
  """Tiny decoder-only model trained as a denoising (fill-in) decoder."""
  hparams = mtf_transformer_tiny_lm()
  hparams.decoder_type = "denoising"
  # Noising specs are dicts with "type" and "prob" keys (see the defaults in
  # mtf_transformer_base); a plain tuple cannot be looked up by key.
  hparams.noising_spec_train = {"type": "random_zipfian", "prob": 0.3}
  hparams.noising_use_eval_during_train = 0.5
  hparams.max_length = 1024
  return hparams
@registry.register_hparams
def mtf_transformer_single():
  """mtf_transformer_tiny with an empty mesh_shape (no mesh dimensions)."""
  hparams = mtf_transformer_tiny()
  hparams.mesh_shape = ""
  return hparams
@registry.register_hparams
def mtf_transformer_enc_single():
  """Single-stack non-autoregressive ("encoder") variant of *_single."""
  hparams = mtf_transformer_single()
  hparams.transformer_type = "encoder"
  return hparams
@registry.register_hparams
def mtf_transformer_tiny_8gpu():
  """mtf_transformer_tiny with 8-way model parallelism."""
  hparams = mtf_transformer_tiny()
  hparams.mesh_shape = "model:8"
  return hparams
def mtf_transformer_paper_lm(size):
  """Config for language-model experiments.

  Train these on languagemodel_lm1b32k_packed for 136000 steps (10 epochs)

  The size parameter is an integer that controls the number of heads and the
  size of the feedforward hidden layers.  Increasing size by 1
  doubles each of these.

  Results:
  size   params/10^9  log-ppl(per-token)
  -1     0.14         3.209
  0      0.22         3.119
  1      0.37         3.037
  2      0.67         2.969
  3      1.28         2.912
  4      2.48         2.874
  5      4.90         2.871

  (to get word-level log-ppl, multiply by 1.1078)

  Args:
    size: an integer

  Returns:
    a hparams object
  """
  scale = 2 ** size
  hparams = mtf_transformer_base_lm()
  hparams.shared_embedding_and_softmax_weights = False
  hparams.batch_size = 256
  hparams.d_model = 1024
  hparams.d_kv = 256
  # Heads and feed-forward width both grow with 2**size.
  hparams.num_heads = int(8 * scale)
  hparams.d_ff = int(8192 * scale)
  # one epoch for languagemodel_lm1b32k_packed = 13600 steps
  hparams.learning_rate_decay_steps = 13600
  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_m1():
  """Paper LM at size -1 on a 32-way batch (data-parallel) mesh."""
  hparams = mtf_transformer_paper_lm(size=-1)
  hparams.mesh_shape = "batch:32"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_0():
  """Paper LM at size 0 on a 32-way batch (data-parallel) mesh."""
  hparams = mtf_transformer_paper_lm(size=0)
  hparams.mesh_shape = "batch:32"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_1():
  """Paper LM at size 1 on a 4 (model) x 8 (batch) mesh."""
  hparams = mtf_transformer_paper_lm(size=1)
  hparams.mesh_shape = "model:4;batch:8"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_2():
  """Paper LM at size 2 on a 4 (model) x 8 (batch) mesh."""
  hparams = mtf_transformer_paper_lm(size=2)
  hparams.mesh_shape = "model:4;batch:8"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_3():
  """Paper LM at size 3 on an 8 (model) x 16 (batch) mesh."""
  hparams = mtf_transformer_paper_lm(size=3)
  hparams.mesh_shape = "model:8;batch:16"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_4():
  """Paper LM at size 4 on a 16 (batch) x 32 (model) mesh."""
  hparams = mtf_transformer_paper_lm(size=4)
  hparams.mesh_shape = "batch:16;model:32"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_lm_5():
  """Paper LM at size 5 on a 16 (batch) x 32 (model) mesh."""
  hparams = mtf_transformer_paper_lm(size=5)
  hparams.mesh_shape = "batch:16;model:32"
  return hparams
def mtf_transformer_paper_tr(size):
  """Config for translation experiments.

  Train these on translate_enfr_wmt32k_packed for 154000 steps (3 epochs)

  The size parameter is an integer that controls the number of heads and the
  size of the feedforward hidden layers.  Increasing size by 1
  doubles each of these.

  Args:
    size: an integer

  Returns:
    a hparams object
  """
  scale = 2 ** size
  hparams = mtf_transformer_base()
  hparams.shared_embedding_and_softmax_weights = False
  hparams.label_smoothing = 0.1
  hparams.batch_size = 128
  hparams.d_model = 1024
  # Heads and feed-forward width both grow with 2**size.
  hparams.num_heads = int(8 * scale)
  hparams.d_ff = int(4096 * scale)
  # one epoch for translate_enfr_wmt32k_packed = 51400 steps
  hparams.learning_rate_decay_steps = 51400
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_m1():
  """Paper translation model at size -1 on a 32-way batch mesh."""
  hparams = mtf_transformer_paper_tr(size=-1)
  hparams.mesh_shape = "batch:32"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0():
  """Paper translation model at size 0 on a 32-way batch mesh."""
  hparams = mtf_transformer_paper_tr(size=0)
  hparams.mesh_shape = "batch:32"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0_a32():
  """mtf_transformer_paper_tr_0 with float32 activations."""
  hparams = mtf_transformer_paper_tr_0()
  hparams.activation_dtype = "float32"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0_nf():
  """mtf_transformer_paper_tr_0 with optimizer_adafactor_factored turned off."""
  hparams = mtf_transformer_paper_tr_0()
  hparams.optimizer_adafactor_factored = False
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_1():
  """Paper translation model at size 1 on a 4 (model) x 8 (batch) mesh."""
  hparams = mtf_transformer_paper_tr(size=1)
  hparams.mesh_shape = "model:4;batch:8"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_2():
  """Paper translation model at size 2 on a 4 (model) x 8 (batch) mesh."""
  hparams = mtf_transformer_paper_tr(size=2)
  hparams.mesh_shape = "model:4;batch:8"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_3():
  """Paper translation model at size 3 on an 8 (model) x 16 (batch) mesh."""
  hparams = mtf_transformer_paper_tr(size=3)
  hparams.mesh_shape = "model:8;batch:16"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_4():
  """Paper translation model at size 4 on an 8 (model) x 16 (batch) mesh."""
  hparams = mtf_transformer_paper_tr(size=4)
  hparams.mesh_shape = "model:8;batch:16"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0_mesh_8():
  """Paper translation model at size 0 on an 8-way batch mesh."""
  hparams = mtf_transformer_paper_tr(size=0)
  hparams.mesh_shape = "batch:8"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_4_mesh_16_8():
  """Paper translation model at size 4 on an 8 (batch) x 16 (model) mesh."""
  hparams = mtf_transformer_paper_tr(size=4)
  hparams.mesh_shape = "batch:8;model:16"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_6_mesh_64_8():
  """Paper translation model at size 6 on a 64 (model) x 8 (batch) mesh."""
  # Note: This mesh shape does align well with physical [16, 16, 2] topology.
  hparams = mtf_transformer_paper_tr(size=6)
  hparams.mesh_shape = "model:64;batch:8"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0_mesh_8_v2():
  """Size-0 translation model on an 8-way batch mesh with 1/4 the batch size."""
  hparams = mtf_transformer_paper_tr(0)
  # Integer division: avoids the float round-trip of int(batch_size / 4).
  hparams.batch_size = hparams.batch_size // 4
  hparams.mesh_shape = "batch:8"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0_mesh_128():
  """Size-0 translation model on a 128-way batch mesh with 4x the batch size."""
  hparams = mtf_transformer_paper_tr(0)
  # batch_size is an int, so the product needs no int() cast.
  hparams.batch_size = hparams.batch_size * 4
  hparams.mesh_shape = "batch:128"
  return hparams
@registry.register_hparams
def mtf_transformer_paper_tr_0_mesh_512():
  """Size-0 translation model on a 512-way batch mesh with 16x the batch size."""
  hparams = mtf_transformer_paper_tr(0)
  # batch_size is an int, so the product needs no int() cast.
  hparams.batch_size = hparams.batch_size * 16
  hparams.mesh_shape = "batch:512"
  return hparams
@registry.register_hparams
def mtf_transformer_lm_baseline():
  """Small language model to run on 1 TPU.

  Run this on 2x2 on languagemodel_lm1b32k_packed for 272000 steps (10 epochs)

  Results:
         params/10^9  log-ppl(per-token)
         0.14         3.202

  Returns:
    a hparams
  """
  hparams = mtf_transformer_paper_lm(size=-1)
  hparams.mesh_shape = "batch:8"
  hparams.batch_size = 128
  # one epoch on lm1b at this batch size
  hparams.learning_rate_decay_steps = 27200
  return hparams
================================================
FILE: tensor2tensor/models/mtf_transformer2.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import moe
from mesh_tensorflow.transformer import transformer
from mesh_tensorflow.transformer import transformer_layers
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.utils import mtf_model
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
@registry.register_model
class MtfUnitransformer(mtf_model.MtfModel):
  """Single-stack Transformer (Transformer Decoder) in mesh_tensorflow.

  Can optionally be autoregressive (language generation) or non-autoregressive
  like BERT.
  """

  @property
  def batch_dims(self):
    """Batch dimensions: ["batch"], or ["outer_batch", "inner_batch"].

    Raises:
      ValueError: if hparams.outer_batch_size does not divide
        hparams.batch_size.
    """
    hparams = self._hparams
    if hparams.outer_batch_size == 0:
      return [mtf.Dimension("batch", hparams.batch_size)]
    else:
      if hparams.batch_size % hparams.outer_batch_size != 0:
        raise ValueError(
            "hparams.outer_batch_size must divide hparams.batch_size")
      return [
          mtf.Dimension("outer_batch", hparams.outer_batch_size),
          mtf.Dimension("inner_batch",
                        hparams.batch_size // hparams.outer_batch_size)]

  def combine_batch_dims(self, x):
    """Merge multiple batch dims of `x` into one; no-op for a single one."""
    if len(self.batch_dims) <= 1:
      return x
    return mtf.replace_dimensions(
        x, self.batch_dims, mtf.combined_dimension(self.batch_dims))

  @property
  def autoregressive(self):
    return self._hparams.autoregressive

  @property
  def variable_dtype(self):
    """(master, slice, activation) dtypes taken from hparams."""
    return mtf.VariableDType(
        tf.as_dtype(self._hparams.master_dtype),
        tf.as_dtype(self._hparams.slice_dtype),
        tf.as_dtype(self._hparams.activation_dtype))

  @property
  def length_dim(self):
    """Sequence dim; hparams.length if set (truthy), else max_length."""
    return mtf.Dimension(
        "length", self._hparams.length or self._hparams.max_length)

  def _import_to_batch_by_length(self, x, name, mesh):
    mtf_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    x = tf.reshape(x, mtf_shape.to_integer_list)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=name)

  def _import_feature(self, features, mesh, key):
    """Import a feature from the features dictionary into a mtf.Tensor.

    The feature is truncated to length_dim and zero-padded up to the full
    batch and length.

    Args:
      features: a features dictionary
      mesh: a Mesh
      key: a string

    Returns:
      a mtf.Tensor with dtype int32 and shape self.batch_dims + self.length_dim,
      or None if `key` is not in `features`.
    """
    if key not in features:
      return None
    x = tf.to_int32(features[key])
    x = common_layers.expand_squeeze_to_nd(x, 2)
    batch_size = mtf.Shape(self.batch_dims).size
    x = x[:, :self.length_dim.size]
    extra_length = self.length_dim.size - tf.shape(x)[1]
    extra_batch = batch_size - tf.shape(x)[0]
    x = tf.pad(x, [[0, extra_batch], [0, extra_length]])
    mtf_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    x = tf.reshape(x, mtf_shape.to_integer_list)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  def model(self):
    """Build the transformer.Unitransformer described by hparams.

    Raises:
      NotImplementedError: if hparams.label_smoothing is nonzero.
    """
    hparams = self._hparams
    if hparams.label_smoothing != 0:
      raise NotImplementedError(
          "Label smoothing not implemented in unitransformer."
          "  Do you really want it?")
    layer_stack = layer_stack_from_hparams(hparams, "")
    if self.autoregressive:
      input_vocab_size = self._targets_vocab_size
    else:
      input_vocab_size = self._inputs_vocab_size
    return transformer.Unitransformer(
        layer_stack=layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=input_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=self.autoregressive,
        max_length=hparams.max_length,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        z_loss=hparams.z_loss,
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape)

  def _mtf_model_fn(self, features, mesh):
    """Return (logits, loss) for a training/eval batch."""
    self._original_features = features
    hparams = self._hparams
    def import_feature(key):
      return self._import_feature(features, mesh, key)
    targets = import_feature("targets")
    sequence_id = import_feature("targets_segmentation")
    if hparams.use_global_position_in_packed_sequence:
      position = None
    else:
      position = import_feature("targets_position")
    if self.autoregressive:
      inputs = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
      # We should have a 0 at the beginning of each sequence rather than the
      # shifted EOS (1) from the previous sequence.
      inputs -= mtf.to_int32(mtf.equal(inputs, 1))
    else:
      inputs = import_feature("inputs")
      # TODO(noam): options for bert-style masking here?
    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        sequence_id=sequence_id,
        position=position)
    return logits, loss

  def mtf_model_fn(self, features, mesh):
    logits, loss = self._mtf_model_fn(features, mesh)
    # combine batch dims
    logits = self.combine_batch_dims(logits)
    return logits, loss

  @property
  def _targets_vocab_size(self):
    """Targets vocab size rounded up to a multiple of vocab_divisor."""
    targets_vocab_size = self._problem_hparams.vocab_size["targets"]
    targets_vocab_size += (-targets_vocab_size) % self._hparams.vocab_divisor
    return targets_vocab_size

  @property
  def _inputs_vocab_size(self):
    """Inputs vocab size rounded up to a multiple of vocab_divisor."""
    inputs_vocab_size = self._problem_hparams.vocab_size["inputs"]
    inputs_vocab_size += (-inputs_vocab_size) % self._hparams.vocab_divisor
    return inputs_vocab_size

  def sample(self, features, mesh):
    """Autoregressively sample, optionally forcing a prefix.

    Raises:
      NotImplementedError: if hparams.beam_size > 1.
      ValueError: if the model is not autoregressive.
    """
    hparams = self._hparams
    model = self.model()
    def import_feature(key):
      return self._import_feature(features, mesh, key)

    if self.autoregressive:
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = import_feature("inputs")
      if partial_targets is None:
        partial_targets = import_feature("targets")
      # Test for presence explicitly rather than relying on the truthiness of
      # a mtf.Tensor.
      if partial_targets is not None:
        # Zero out EOS (id 1) tokens in the forced prefix (presumably so the
        # prefix does not terminate sampling).
        partial_targets *= mtf.cast(
            mtf.not_equal(partial_targets, 1), partial_targets.dtype)
      else:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
      if hparams.beam_size > 1:
        raise NotImplementedError(
            "Beam search not implemented for unitransformer.")
      ret = model.sample_autoregressive(
          partial_targets,
          temperature=hparams.sampling_temp,
          variable_dtype=self.variable_dtype)
      return self.combine_batch_dims(ret)
    else:
      raise ValueError(
          "Don't know how to sample from non-autoregressive unitransformer")
@registry.register_model
class MtfBitransformer(MtfUnitransformer):
  """Encoder-Decoder Transformer in mesh_tensorflow."""

  def model(self):
    """Build the encoder/decoder transformer.Bitransformer from hparams."""
    hparams = self._hparams
    encoder_layer_stack = layer_stack_from_hparams(hparams, "encoder_")
    decoder_layer_stack = layer_stack_from_hparams(hparams, "decoder_")
    encoder = transformer.Unitransformer(
        layer_stack=encoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._inputs_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        max_length=hparams.max_length,
        name="encoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    decoder = transformer.Unitransformer(
        layer_stack=decoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._targets_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=True,
        max_length=hparams.max_length,
        label_smoothing=hparams.label_smoothing,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        z_loss=hparams.z_loss,
        name="decoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    return transformer.Bitransformer(
        encoder, decoder, shared_embedding=hparams.shared_embedding)

  def _mtf_model_fn(self, features, mesh):
    """Return (logits, loss) for a training/eval batch.

    Raises:
      ValueError: if the "inputs" feature is missing.
    """
    self._original_features = features
    hparams = self._hparams
    def import_feature(key):
      return self._import_feature(features, mesh, key)
    targets = import_feature("targets")
    inputs = import_feature("inputs")
    # _import_feature signals absence with None; test for it explicitly
    # instead of relying on the truthiness of a mtf.Tensor.
    if inputs is None:
      raise ValueError("inputs feature is missing")
    encoder_sequence_id = import_feature("inputs_segmentation")
    if encoder_sequence_id is None:
      # Unpacked data: every non-padding token belongs to sequence 1.
      encoder_sequence_id = mtf.to_int32(mtf.not_equal(inputs, 0))
    decoder_sequence_id = import_feature("targets_segmentation")
    if decoder_sequence_id is None:
      decoder_sequence_id = mtf.to_int32(mtf.not_equal(targets, 0))
    if hparams.use_global_position_in_packed_sequence:
      encoder_position = None
      decoder_position = None
    else:
      encoder_position = import_feature("inputs_position")
      decoder_position = import_feature("targets_position")
    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        encoder_sequence_id=encoder_sequence_id,
        decoder_sequence_id=decoder_sequence_id,
        encoder_position=encoder_position,
        decoder_position=decoder_position)
    return logits, loss

  def sample(self, features, mesh):
    """Decode from features["inputs"] (beam search when beam_size > 1).

    Raises:
      ValueError: if the "inputs" feature is missing.
    """
    hparams = self._hparams
    model = self.model()
    inputs = self._import_feature(features, mesh, "inputs")
    if inputs is None:
      raise ValueError("inputs feature is missing")
    ret = model.decode(
        inputs,
        self.variable_dtype,
        beam_size=hparams.beam_size,
        alpha=hparams.alpha,
        temperature=hparams.sampling_temp if hparams.beam_size == 1 else 0,
        decode_length_multiplier=hparams.decode_length_multiplier,
        decode_length_constant=hparams.decode_length_constant)
    return self.combine_batch_dims(ret)
# Registry of layer builders, keyed by the layer-type strings that appear in
# hparams.<prefix>layers. Builders are registered below with
# @layers_registry.register(name) and called as builder(hparams, prefix)
# by layer_stack_from_hparams.
layers_registry = registry.Registries.mtf_layers
# The following functions construct layers based on hyperparameters
def attention_kwargs_from_hparams(hparams):
  """Attention keyword arguments shared by all attention layer builders.

  Returns:
    a dict with "dropout_rate" (hparams.attention_dropout) and "extra_logit"
    (0.0 when hparams.extra_logit is truthy, otherwise None).
  """
  dropout_rate = hparams.attention_dropout
  if hparams.extra_logit:
    extra_logit = 0.0
  else:
    extra_logit = None
  return dict(dropout_rate=dropout_rate, extra_logit=extra_logit)
@layers_registry.register("self_att")
def self_attention_layer(hparams, prefix):
  """Create a SelfAttention layer from hyperparameters.

  Head counts and shared_kv are read with `prefix` prepended
  (e.g. "encoder_num_heads"); shared_kv defaults to False.
  """
  num_heads = hparams.get(prefix + "num_heads")
  num_memory_heads = hparams.get(prefix + "num_memory_heads")
  key_value_size = hparams.d_kv
  shared_kv = hparams.get(prefix + "shared_kv", False)
  return transformer_layers.SelfAttention(
      num_heads=num_heads,
      num_memory_heads=num_memory_heads,
      key_value_size=key_value_size,
      shared_kv=shared_kv,
      attention_kwargs=attention_kwargs_from_hparams(hparams))
@layers_registry.register("local_self_att")
def local_self_attention_layer(hparams, prefix):
  """Create a LocalSelfAttention layer from hyperparameters.

  The attention radius comes from hparams.local_attention_radius; head counts
  and shared_kv are read with `prefix` prepended.
  """
  num_heads = hparams.get(prefix + "num_heads")
  num_memory_heads = hparams.get(prefix + "num_memory_heads")
  radius = hparams.local_attention_radius
  key_value_size = hparams.d_kv
  shared_kv = hparams.get(prefix + "shared_kv", False)
  return transformer_layers.LocalSelfAttention(
      num_heads=num_heads,
      num_memory_heads=num_memory_heads,
      radius=radius,
      key_value_size=key_value_size,
      shared_kv=shared_kv,
      attention_kwargs=attention_kwargs_from_hparams(hparams))
@layers_registry.register("enc_att")
def enc_dec_attention_layer(hparams, prefix):
  """Create encoder-decoder attention layer based on hyperparameters.

  Args:
    hparams: hyperparameters. Per-stack values are read as
      `prefix + "num_heads"`, `prefix + "num_memory_heads"` and
      `prefix + "shared_kv"` (default False).
    prefix: string prepended to the per-stack hparam names.

  Returns:
    a transformer_layers.EncDecAttention layer.
  """
  num_heads = hparams.get(prefix + "num_heads")
  num_memory_heads = hparams.get(prefix + "num_memory_heads")
  shared_kv = hparams.get(prefix + "shared_kv", False)
  return transformer_layers.EncDecAttention(
      num_heads=num_heads,
      num_memory_heads=num_memory_heads,
      key_value_size=hparams.d_kv,
      shared_kv=shared_kv,
      attention_kwargs=attention_kwargs_from_hparams(hparams))
@layers_registry.register("drd")
def dense_relu_dense_layer(hparams, prefix):
  """Create a dense-relu-dense feed-forward layer.

  `prefix` is accepted for signature compatibility with the other layer
  factories but ignored: `d_ff` and `relu_dropout` are shared by all stacks.
  """
  del prefix  # unused
  hidden_size = hparams.d_ff
  dropout_rate = hparams.relu_dropout
  return transformer_layers.DenseReluDense(
      hidden_size=hidden_size, dropout_rate=dropout_rate)
@layers_registry.register("moe_1d")
def moe_1d_layer(hparams, prefix):
  """Create a one-dimensional mixture-of-experts layer.

  Uses `moe_num_experts` experts of size `moe_hidden_size`; `prefix` is
  ignored (these hparams are shared by all stacks).
  """
  del prefix  # unused
  num_experts = hparams.moe_num_experts
  hidden_size = hparams.moe_hidden_size
  return moe.MoE1D(num_experts=num_experts, hidden_size=hidden_size)
@layers_registry.register("moe_2d")
def moe_2d_layer(hparams, prefix):
  """Create a two-dimensional (hierarchical) mixture-of-experts layer.

  The expert grid is `moe_expert_x` by `moe_expert_y`, each expert of size
  `moe_hidden_size`; `prefix` is ignored (these hparams are shared).
  """
  del prefix  # unused
  expert_x, expert_y = hparams.moe_expert_x, hparams.moe_expert_y
  return moe.MoE2D(expert_x=expert_x,
                   expert_y=expert_y,
                   hidden_size=hparams.moe_hidden_size)
def layer_stack_from_hparams(hparams, prefix):
  """Create a layer stack based on the hyperparameter values.

  Args:
    hparams: hyperparameters; `prefix + "layers"` must list names registered
      in `layers_registry`.
    prefix: string prepended to per-stack hparam names.

  Returns:
    a transformer.LayerStack with one layer per listed name, in order.
  """
  layer_names = hparams.get(prefix + "layers")
  stack = []
  for layer_name in layer_names:
    build_layer = layers_registry[layer_name]
    stack.append(build_layer(hparams, prefix))
  return transformer.LayerStack(
      stack,
      dropout_rate=hparams.layer_prepostprocess_dropout,
      norm_epsilon=hparams.norm_epsilon)
def mtf_transformer2_base():
  """Hyperparameters common to both unitransformer and bitransformer.

  Starts from common_hparams.basic_params1(). It deletes the generic
  "num_heads" and "num_hidden_layers" hparams; the unitransformer and
  bitransformer configs add their own head-count hparams afterwards.

  Returns:
    a hparams
  """
  hparams = common_hparams.basic_params1()
  hparams.add_hparam("d_model", 1024)
  hparams.batch_size = 4
  hparams.max_length = 1024
  hparams.label_smoothing = 0.0
  # a small positive value - this seems important for stability when training
  # with bfloat16 activations.
  hparams.add_hparam("z_loss", 1e-4)
  # hparams applying to both encoder and decoder layer stacks.
  hparams.add_hparam("d_ff", 2048)
  hparams.add_hparam("d_kv", 128)
  hparams.add_hparam("attention_dropout", 0.0)
  hparams.add_hparam("relu_dropout", 0.0)
  # Generic values removed; derived configs define (possibly prefixed)
  # replacements.
  hparams.del_hparam("num_heads")
  hparams.del_hparam("num_hidden_layers")
  hparams.layer_prepostprocess_dropout = 0.0
  hparams.add_hparam("extra_logit", False)
  # number of experts for moe_1d
  hparams.moe_num_experts = 32
  # number of experts for moe_2d = moe_expert_x * moe_expert_y
  hparams.add_hparam("moe_expert_x", 8)
  hparams.add_hparam("moe_expert_y", 4)
  hparams.add_hparam("moe_hidden_size", 32768)
  # round up vocab sizes to be a multiple of this value
  hparams.vocab_divisor = 128
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay*linear_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.add_hparam("master_dtype", "bfloat16")
  hparams.add_hparam("slice_dtype", "float32")
  hparams.activation_dtype = "bfloat16"
  # 8-way model-parallelism
  hparams.add_hparam("mesh_shape", "model:8")
  hparams.add_hparam("layout", "batch:batch;vocab:model;d_ff:model;heads:model")
  # If nonzero, we split the batch across two tensor-dimensions named
  # "outer_batch" and "inner_batch", allowing for splitting across two mesh
  # dimensions. This is necessary for hierarchical mixture of experts.
  # The two tensor dimensions have sizes hparams.outer_batch_size and
  # hparams.batch_size // hparams.outer_batch_size.
  hparams.add_hparam("outer_batch_size", 0)
  hparams.shared_embedding_and_softmax_weights = False
  # length for training or decoding - defaults to max_length
  hparams.add_hparam("length", 0)
  # These parameters make Transformer model compatible with mtf
  # Do not override these.
  hparams.no_data_parallelism = True
  hparams.use_fixed_batch_size = True
  hparams.add_hparam("mtf_mode", True)
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.bottom = {
      "inputs": modalities.identity_bottom,
      "targets": modalities.identity_bottom,
  }
  hparams.top = {
      "targets": modalities.identity_top,
  }
  hparams.add_hparam("beam_size", 1)
  # If this is True, then in a packed dataset (where examples are concatenated
  # to form longer examples) we use the global position (within the concatenated
  # sequence) to compute the positional embedding, instead of the position
  # within the individual sequence. This is counterintuitive, but for some
  # reason, it keeps the model from diverging.
  hparams.add_hparam("use_global_position_in_packed_sequence", True)
  return hparams
@registry.register_hparams
def mtf_unitransformer_base():
  """Hyperparameters for single-stack Transformer.

  Extends mtf_transformer2_base with the (unprefixed) settings of the single
  layer stack.
  """
  hparams = mtf_transformer2_base()
  single_stack_hparams = (
      ("autoregressive", True),
      # Layer types of the single layer stack, in order.
      ("layers", ["self_att", "drd"] * 6),
      # Number of heads in multihead attention.
      ("num_heads", 8),
      # 0 gives standard transformer behavior; 1 means a single set of keys
      # and values that is read by all query heads.
      ("num_memory_heads", 0),
      # Whether attention keys and values are shared.
      ("shared_kv", False),
      # Attention radius used by "local_self_att" layers.
      ("local_attention_radius", 128),
  )
  for name, value in single_stack_hparams:
    hparams.add_hparam(name, value)
  return hparams
@registry.register_hparams
def mtf_bitransformer_base():
  """Machine translation base configuration.

  Extends mtf_transformer2_base with "encoder_"/"decoder_"-prefixed layer
  stack settings and decoding parameters.
  """
  hparams = mtf_transformer2_base()
  hparams.max_length = 256
  hparams.shared_embedding = True
  encoder_decoder_hparams = (
      # Layer types of each stack, in order.
      ("encoder_layers", ["self_att", "drd"] * 6),
      ("decoder_layers", ["self_att", "enc_att", "drd"] * 6),
      ("encoder_num_layers", 6),
      ("decoder_num_layers", 6),
      # Number of heads in multihead attention.
      ("encoder_num_heads", 8),
      ("decoder_num_heads", 8),
      ("local_attention_radius", 128),
      # 0 gives standard transformer behavior; 1 means a single set of keys
      # and values that is read by all query heads.
      ("encoder_num_memory_heads", 0),
      ("decoder_num_memory_heads", 0),
      # Whether attention keys and values are shared.
      ("encoder_shared_kv", False),
      ("decoder_shared_kv", False),
      # Maximum decode length in beam search is
      #   min(max_length,
      #       decode_length_multiplier * input_length + decode_length_constant)
      ("decode_length_multiplier", 1.5),
      ("decode_length_constant", 10.0),
      # Passed as `alpha` when decoding.
      ("alpha", 0.6),
  )
  for name, value in encoder_decoder_hparams:
    hparams.add_hparam(name, value)
  hparams.sampling_temp = 0.0
  return hparams
@registry.register_hparams
def mtf_unitransformer_tiny():
  """Tiny single-stack model on an empty mesh, for local testing."""
  hparams = mtf_unitransformer_base()
  hparams.mesh_shape = ""
  hparams.batch_size = 2
  # Two self-attention + feed-forward blocks at reduced width.
  hparams.layers = ["self_att", "drd"] * 2
  hparams.d_model = 128
  hparams.d_ff = 512
  hparams.num_heads = 4
  return hparams
@registry.register_hparams
def mtf_bitransformer_tiny():
  """Small encoder-decoder model for testing."""
  hparams = mtf_bitransformer_base()
  hparams.batch_size = 2
  hparams.mesh_shape = ""
  hparams.d_model = 128
  hparams.encoder_layers = ["self_att", "drd"] * 2
  hparams.decoder_layers = ["self_att", "enc_att", "drd"] * 2
  # The bitransformer defines per-stack head counts. The generic "num_heads"
  # hparam is deleted in mtf_transformer2_base, so setting it had no effect
  # on the layers.
  hparams.encoder_num_heads = 4
  hparams.decoder_num_heads = 4
  hparams.d_ff = 512
  return hparams
@registry.register_hparams
def mtf_unitransformer_all_layers_tiny():
  """Test out all the layers on local CPU.

  One layer each of self_att, local_self_att, moe_1d, moe_2d and drd on top of
  mtf_unitransformer_tiny, with a shrunken mixture-of-experts setup.
  """
  hparams = mtf_unitransformer_tiny()
  hparams.layers = ["self_att", "local_self_att", "moe_1d", "moe_2d", "drd"]
  # Small experts so the MoE layers fit on a CPU.
  hparams.moe_num_experts = 4
  hparams.moe_expert_x = 4
  hparams.moe_expert_y = 4
  hparams.moe_hidden_size = 512
  return hparams
@registry.register_hparams
def mtf_bitransformer_all_layers_tiny():
  """Test out all the layers on local CPU.

  Every layer type appears once in each stack (enc_att only in the decoder),
  on top of mtf_bitransformer_tiny with a shrunken mixture-of-experts setup.
  """
  hparams = mtf_bitransformer_tiny()
  hparams.encoder_layers = [
      "self_att", "local_self_att", "moe_1d", "moe_2d", "drd"]
  hparams.decoder_layers = [
      "self_att", "local_self_att", "enc_att", "moe_1d", "moe_2d", "drd"]
  # Small experts so the MoE layers fit on a CPU.
  hparams.moe_num_experts = 4
  hparams.moe_expert_x = 4
  hparams.moe_expert_y = 4
  hparams.moe_hidden_size = 512
  return hparams
@registry.register_hparams
def mtr_lm_dense(sz):
  """Series of architectures for language modeling.

  We assume infinite training data, so no dropout necessary.
  You can use languagemodel_wiki_noref_v32k_l1k.
  (1 epoch = ~46000 steps).

  TODO(noam): find a large enough dataset for these experiments.

  Args:
    sz: an integer; d_ff and num_heads scale with 2 ** sz.

  Returns:
    a hparams
  """
  scale = 2 ** sz
  hparams = mtf_unitransformer_base()
  # Model width.
  hparams.d_model = 1024
  hparams.d_ff = 8192 * scale
  hparams.d_kv = 256
  hparams.num_heads = 8 * scale
  # Training setup.
  hparams.max_length = 1024
  hparams.batch_size = 128
  hparams.learning_rate_decay_steps = 65536
  # NOTE(review): "num_hidden_layers" is deleted in mtf_transformer2_base and
  # the depth comes from `hparams.layers`; this assignment looks unused.
  hparams.num_hidden_layers = 6
  # 32-way split of the batch dimension.
  hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
  hparams.mesh_shape = "batch:32"
  return hparams
@registry.register_hparams
def mtr_lm_dense_0():
  """mtr_lm_dense at size 0."""
  hparams = mtr_lm_dense(sz=0)
  return hparams
@registry.register_hparams
def mtr_lm_dense_0_h1_16():
  """mtr_lm_dense_0 with 16 query heads and a single memory (key/value) head."""
  hparams = mtr_lm_dense_0()
  # mtr_lm_dense is built on mtf_unitransformer_base, which only defines the
  # unprefixed "num_heads"/"num_memory_heads" hparams. The "decoder_"-prefixed
  # names belong to the bitransformer and would be ignored here.
  hparams.num_heads = 16
  hparams.num_memory_heads = 1
  return hparams
@registry.register_hparams
def mtr_lm_dense_1():
  """mtr_lm_dense at size 1."""
  hparams = mtr_lm_dense(sz=1)
  return hparams
@registry.register_hparams
def mtr_lm_dense_2():
  """mtr_lm_dense at size 2 on a 4 (model) x 8 (batch) mesh."""
  params = mtr_lm_dense(sz=2)
  params.mesh_shape = "model:4;batch:8"
  return params
@registry.register_hparams
def mtr_lm_dense_3():
  """mtr_lm_dense at size 3 on a 4 (model) x 8 (batch) mesh."""
  params = mtr_lm_dense(sz=3)
  params.mesh_shape = "model:4;batch:8"
  return params
@registry.register_hparams
def mtr_lm_v1():
  """Model incorporating mixture-of-experts, local and global attention.

  ~6B parameters
  32 experts in 3 hierarchical moe layers.

  Returns:
    a hparams
  """
  hparams = mtr_lm_dense(0)
  block = ["local_self_att", "local_self_att", "drd",
           "self_att", "drd", "local_self_att",
           "local_self_att", "moe_2d"]
  # Four copies of the block with the final moe_2d dropped -> 3 MoE layers.
  hparams.layers = (block * 4)[:-1]
  hparams.d_kv = 128
  hparams.d_ff = 2048
  hparams.num_memory_heads = 0
  # 8 x 4 = 32 experts.
  hparams.moe_expert_x = 8
  hparams.moe_expert_y = 4
  hparams.moe_hidden_size = 32768
  # Batch split over two mesh dimensions for hierarchical MoE.
  hparams.mesh_shape = "b0:4;b1:8"
  hparams.layout = "outer_batch:b0;inner_batch:b1,expert_x:b1,expert_y:b0"
  hparams.outer_batch_size = 4
  return hparams
@registry.register_hparams
def mtr_lm_v1_h1_8():
  """Version for fast decoding: mtr_lm_v1 with a single memory head."""
  params = mtr_lm_v1()
  params.num_memory_heads = 1
  return params
def mtr_tr_dense(sz):
  """Series of machine translation models.

  All models are trained on sequences of 256 tokens.
  You can use the dataset translate_enfr_wmt32k_packed.
  154000 steps = 3 epochs.

  Args:
    sz: an integer; d_ff and the head counts scale with 2 ** sz.

  Returns:
    a hparams
  """
  scale = 2 ** sz
  hparams = mtf_bitransformer_base()
  # Model width.
  hparams.d_model = 1024
  hparams.d_ff = int(4096 * scale)
  hparams.d_kv = 128
  hparams.encoder_num_heads = int(8 * scale)
  hparams.decoder_num_heads = int(8 * scale)
  # Training setup.
  hparams.max_length = 256
  hparams.batch_size = 128
  # one epoch for translate_enfr_wmt32k_packed = 51400 steps
  hparams.learning_rate_decay_steps = 51400
  hparams.layout = "batch:batch;vocab:model;d_ff:model;heads:model"
  hparams.mesh_shape = "batch:32"
  # Regularization.
  hparams.label_smoothing = 0.1
  for dropout_name in ("layer_prepostprocess_dropout",
                       "attention_dropout",
                       "relu_dropout"):
    setattr(hparams, dropout_name, 0.1)
  return hparams
@registry.register_hparams
def mtr_tr_dense_0():
  """mtr_tr_dense at size 0."""
  hparams = mtr_tr_dense(sz=0)
  return hparams
@registry.register_hparams
def mtr_tr_dense_1():
  """mtr_tr_dense at size 1."""
  hparams = mtr_tr_dense(sz=1)
  return hparams
@registry.register_hparams
def mtr_tr_dense_2():
  """mtr_tr_dense at size 2 on a 4 (model) x 8 (batch) mesh."""
  params = mtr_tr_dense(sz=2)
  params.mesh_shape = "model:4;batch:8"
  return params
@registry.register_hparams
def mtr_tr_dense_3():
  """mtr_tr_dense at size 3 on a 4 (model) x 8 (batch) mesh."""
  params = mtr_tr_dense(sz=3)
  params.mesh_shape = "model:4;batch:8"
  return params
@registry.register_hparams
def mtr_tr_dense_3_88():
  """mtr_tr_dense at size 3 on an 8 (model) x 16 (batch) mesh."""
  params = mtr_tr_dense(sz=3)
  params.mesh_shape = "model:8;batch:16"
  return params
@registry.register_hparams
def mtr_tr_dense_3_fast():
  """mtr_tr_dense_3 with 128 decoder heads and 8 decoder memory heads."""
  hparams = mtr_tr_dense_3()
  hparams.decoder_num_heads = 128
  hparams.decoder_num_memory_heads = 8
  # NOTE(review): the inherited decoder_layers use "self_att", not
  # "local_self_att", so this radius appears to have no effect - confirm.
  hparams.local_attention_radius = 32
  return hparams
def mtr_tr_dense_local(sz):
  """With local self-attention in the decoder.

  Args:
    sz: size parameter forwarded to mtr_tr_dense.

  Returns:
    a hparams whose decoder self-attention is local with radius 32.
  """
  params = mtr_tr_dense(sz)
  params.local_attention_radius = 32
  params.decoder_layers = ["local_self_att", "enc_att", "drd"] * 6
  return params
@registry.register_hparams
def mtr_tr_dense_local_0():
  """mtr_tr_dense_local at size 0."""
  hparams = mtr_tr_dense_local(sz=0)
  return hparams
@registry.register_hparams
def mtr_tr_dense_local_0_w8():
  """mtr_tr_dense_local_0 with a local attention radius of 8."""
  params = mtr_tr_dense_local_0()
  params.local_attention_radius = 8
  return params
@registry.register_hparams
def mtr_tr_dense_local_0_h1_16():
  """mtr_tr_dense_local_0 with 16 decoder heads and 1 decoder memory head."""
  params = mtr_tr_dense_local_0()
  params.decoder_num_heads, params.decoder_num_memory_heads = 16, 1
  return params
@registry.register_hparams
def mtr_tr_dense_local_0_h1_16_shared():
  """mtr_tr_dense_local_0_h1_16 with tied embedding and softmax weights."""
  params = mtr_tr_dense_local_0_h1_16()
  params.shared_embedding_and_softmax_weights = True
  return params
@registry.register_hparams
def mtr_tr_dense_local_0_h1_8_kv256():
  """mtr_tr_dense_local_0: 8 decoder heads, 1 memory head, d_kv=256."""
  params = mtr_tr_dense_local_0()
  params.d_kv = 256
  params.decoder_num_heads = 8
  params.decoder_num_memory_heads = 1
  return params
@registry.register_hparams
def mtr_tr_dense_local_0_h1_16_shared_kv():
  """mtr_tr_dense_local_0_h1_16 with shared decoder keys and values."""
  params = mtr_tr_dense_local_0_h1_16()
  params.decoder_shared_kv = True
  return params
@registry.register_hparams
def mtr_tr_dense_0_h4():
  """mtr_tr_dense_0 with 4 decoder heads."""
  params = mtr_tr_dense_0()
  params.decoder_num_heads = 4
  return params
@registry.register_hparams
def mtr_tr_dense_0_h16():
  """mtr_tr_dense_0 with 16 decoder heads."""
  params = mtr_tr_dense_0()
  params.decoder_num_heads = 16
  return params
@registry.register_hparams
def mtr_tr_dense_0_extra_logit():
  """mtr_tr_dense_0 with `extra_logit` enabled in attention."""
  params = mtr_tr_dense_0()
  params.extra_logit = True
  return params
@registry.register_hparams
def mtr_tr_dense_0_h1_8():
  """mtr_tr_dense_0 with a single decoder memory head."""
  params = mtr_tr_dense_0()
  params.decoder_num_memory_heads = 1
  return params
@registry.register_hparams
def mtr_tr_dense_0_h1_1():
  """mtr_tr_dense_0 with a single decoder head."""
  params = mtr_tr_dense_0()
  params.decoder_num_heads = 1
  return params
@registry.register_hparams
def mtr_tr_dense_0_h1_16():
  """mtr_tr_dense_0 with 16 decoder heads and 1 decoder memory head."""
  params = mtr_tr_dense_0()
  params.decoder_num_heads, params.decoder_num_memory_heads = 16, 1
  return params
@registry.register_hparams
def mtr_tr_dense_0_h2_16():
  """mtr_tr_dense_0 with 16 decoder heads and 2 decoder memory heads."""
  params = mtr_tr_dense_0()
  params.decoder_num_heads, params.decoder_num_memory_heads = 16, 2
  return params
@registry.register_hparams
def mtr_tr_dense_0_shared_kv():
  """mtr_tr_dense_0 with shared decoder keys and values."""
  params = mtr_tr_dense_0()
  params.decoder_shared_kv = True
  return params
@registry.register_hparams
def mtr_tr_enfr_v0():
  """Good parameters for wmt-en-fr (same as mtr_tr_dense_local_0_h1_16)."""
  return mtr_tr_dense_local_0_h1_16()
@registry.register_hparams
def mtr_tr_ende_v0():
  """Good parameters for wmt-en-de.

  Relative to mtr_tr_dense_local_0_h1_16: shorter learning-rate decay, tied
  embedding/softmax weights and more residual dropout.
  """
  params = mtr_tr_dense_local_0_h1_16()
  params.layer_prepostprocess_dropout = 0.2
  params.shared_embedding_and_softmax_weights = True
  params.learning_rate_decay_steps = 20000
  return params
@registry.register_hparams
def mtr_tr_ende_deep():
  """Deeper, narrower-head variant of mtr_tr_ende_v0."""
  params = mtr_tr_ende_v0()
  params.encoder_num_heads = 4
  params.decoder_num_heads = 8
  params.d_ff = 2048
  # NOTE(review): layer_stack_from_hparams builds stacks from encoder_layers /
  # decoder_layers (still 6 blocks each here). Confirm that the *_num_layers
  # values below are consumed elsewhere; otherwise this config is not deeper.
  params.encoder_num_layers = 12
  params.decoder_num_layers = 12
  return params
================================================
FILE: tensor2tensor/models/mtf_transformer_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Transformer on Mesh TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import mesh_tensorflow as mtf
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import mtf_transformer
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# Constants shared between all functions.
# Dimensions of the random test batch built by get_model(); VOCAB_SIZE is used
# for both the input and the target vocabulary.
BATCH_SIZE = 2
INPUT_LENGTH = 6
TARGET_LENGTH = 6
VOCAB_SIZE = 128
def get_model(hparams=None, mode=tf_estimator.ModeKeys.TRAIN,
              has_input=True, model_cls=mtf_transformer.MtfTransformer):
  """Builds a model plus random int32 features for the tests.

  Args:
    hparams: hparams to use; mtf_transformer_single() when None. Its
      max_length, batch_size and problem_hparams are overwritten.
    mode: estimator mode passed to the model constructor.
    has_input: if False, the "inputs" modality and feature are omitted.
    model_cls: model class to instantiate.

  Returns:
    A (model, features, hparams) tuple.
  """
  hparams = (mtf_transformer.mtf_transformer_single()
             if hparams is None else hparams)
  hparams.max_length = INPUT_LENGTH
  hparams.batch_size = BATCH_SIZE

  p_hparams = problem_hparams.test_problem_hparams(
      VOCAB_SIZE, VOCAB_SIZE, hparams)
  if not has_input:
    del p_hparams.modality["inputs"]
  hparams.problem_hparams = p_hparams

  # Inputs are always drawn first so the random stream does not depend on
  # has_input.
  input_ids = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
  target_ids = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))

  features = dict()
  features["targets"] = tf.constant(target_ids, dtype=tf.int32, name="targets")
  features["target_space_id"] = tf.constant(1, dtype=tf.int32)
  if has_input:
    features["inputs"] = tf.constant(input_ids, dtype=tf.int32, name="inputs")

  model = model_cls(hparams, mode, p_hparams)
  return model, features, hparams
def get_placement_mesh(hparams):
  """Creates a mesh and its placement implementation from hparams.

  Every mesh slot is mapped to the device string "".

  Args:
    hparams: provides `mesh_shape` and `layout`.

  Returns:
    A (mesh, mesh_impl) tuple.
  """
  mesh = mtf.Mesh(mtf.Graph(), "my_mesh")
  shape = mtf.convert_to_shape(hparams.mesh_shape)
  devices = [""] * shape.size
  impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape, hparams.layout, devices)
  return mesh, impl
class MtfTransformerTest(tf.test.TestCase):
  """Checks MtfTransformer logits shapes under several mesh layouts."""

  def _check_logits_shape(self, hparams, mesh_shape, layout):
    """Builds the model on a placement mesh and checks the logits shape.

    Args:
      hparams: hparams passed to get_model().
      mesh_shape: mesh shape string stored in hparams.mesh_shape.
      layout: layout string stored in hparams.layout.
    """
    model, features, hparams = get_model(hparams)
    # Used by get_placement_mesh() to build the mesh implementation.
    hparams.mesh_shape = mesh_shape
    hparams.layout = layout
    mesh, mesh_impl = get_placement_mesh(hparams)

    logits, _ = model.mtf_model_fn(features, mesh)
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    tf_group = lowering.copy_masters_to_slices()
    tf_logits = lowering.export_to_tf_tensor(logits)

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(tf_group)
      res = session.run(tf_logits)
    self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE))

  def testMtfTransformer(self):
    self._check_logits_shape(
        mtf_transformer.mtf_transformer_single(), mesh_shape="", layout="")

  def testMtfTransformerDataParallel(self):
    self._check_logits_shape(
        mtf_transformer.mtf_transformer_single(),
        mesh_shape="all:2", layout="batch:all")

  def testMtfTransformerModelParallel(self):
    self._check_logits_shape(
        mtf_transformer.mtf_transformer_single(),
        mesh_shape="all:2", layout="length:all")

  def testMtfTransformerDataModelParallel(self):
    self._check_logits_shape(
        mtf_transformer.mtf_transformer_single(),
        mesh_shape="batch:2;model:2",
        layout="batch:batch;vocab:model;d_ff:model;heads:model")

  def testMtfTransformerEncoderDataModelParallel(self):
    self._check_logits_shape(
        mtf_transformer.mtf_transformer_enc_single(),
        mesh_shape="batch:2;model:2",
        layout="batch:batch;vocab:model;d_ff:model;heads:model")
if __name__ == "__main__":
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/models/neural_architecture_search/README.md
================================================
This directory contains the configurable model code used in the Evolved
Transformer paper (https://arxiv.org/abs/1901.11117). It can be used to train
models in the search space as was done in the paper.
================================================
FILE: tensor2tensor/models/neural_architecture_search/__init__.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: tensor2tensor/models/neural_architecture_search/nas_layers.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bank of layers for Translation NAS searches.
All encoder layers are registered in the global LayerRegistry ENCODER_LAYERS.
All decoder layers are registered in the global LayerRegistry DECODER_LAYERS.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensor2tensor.layers import common_attention
import tensorflow.compat.v1 as tf
# Registry layer keys.
# String names of the search-space layers, presumably used as the `name`
# argument to LayerRegistry.register_layer() (registrations are not in this
# excerpt - confirm).
ATTEND_TO_ENCODER_REGISTRY_KEY = "attend_to_encoder"
ATTENTION_32_HEADS_REGISTRY_KEY = "attention_32_heads"
ATTENTION_16_HEADS_REGISTRY_KEY = "attention_16_heads"
ATTENTION_4_HEADS_REGISTRY_KEY = "attention_4_heads"
DEPTHWISE_CONV_3X1_REGISTRY_KEY = "depthwise_conv_3x1"
DEPTHWISE_CONV_5X1_REGISTRY_KEY = "depthwise_conv_5x1"
DEPTHWISE_CONV_7X1_REGISTRY_KEY = "depthwise_conv_7x1"
DILATED_CONV_3X1_REGISTRY_KEY = "dilated_conv_3x1"
DILATED_CONV_5X1_REGISTRY_KEY = "dilated_conv_5x1"
GATED_LINEAR_UNIT_REGISTRY_KEY = "gated_linear_unit"
IDENTITY_REGISTRY_KEY = "identity"
# Lightweight convolution naming convention uses "R_X" where X is the variable
# reduction factor.
LIGHTWEIGHT_CONV_3X1_R_1_REGISTRY_KEY = "lightweight_conv_3x1_r_1"
LIGHTWEIGHT_CONV_3X1_R_4_REGISTRY_KEY = "lightweight_conv_3x1_r_4"
LIGHTWEIGHT_CONV_3X1_R_16_REGISTRY_KEY = "lightweight_conv_3x1_r_16"
LIGHTWEIGHT_CONV_5X1_R_1_REGISTRY_KEY = "lightweight_conv_5x1_r_1"
LIGHTWEIGHT_CONV_5X1_R_4_REGISTRY_KEY = "lightweight_conv_5x1_r_4"
LIGHTWEIGHT_CONV_5X1_R_16_REGISTRY_KEY = "lightweight_conv_5x1_r_16"
LIGHTWEIGHT_CONV_7X1_R_1_REGISTRY_KEY = "lightweight_conv_7x1_r_1"
LIGHTWEIGHT_CONV_7X1_R_4_REGISTRY_KEY = "lightweight_conv_7x1_r_4"
LIGHTWEIGHT_CONV_7X1_R_16_REGISTRY_KEY = "lightweight_conv_7x1_r_16"
LIGHTWEIGHT_CONV_15X1_R_1_REGISTRY_KEY = "lightweight_conv_15x1_r_1"
LIGHTWEIGHT_CONV_15X1_R_4_REGISTRY_KEY = "lightweight_conv_15x1_r_4"
LIGHTWEIGHT_CONV_15X1_R_16_REGISTRY_KEY = "lightweight_conv_15x1_r_16"
SEPARABLE_CONV_3X1_REGISTRY_KEY = "separable_conv_3x1"
SEPARABLE_CONV_5X1_REGISTRY_KEY = "separable_conv_5x1"
SEPARABLE_CONV_7X1_REGISTRY_KEY = "separable_conv_7x1"
SEPARABLE_CONV_9X1_REGISTRY_KEY = "separable_conv_9x1"
SEPARABLE_CONV_11X1_REGISTRY_KEY = "separable_conv_11x1"
SEPARABLE_CONV_13X1_REGISTRY_KEY = "separable_conv_13x1"
SEPARABLE_CONV_15X1_REGISTRY_KEY = "separable_conv_15x1"
STANDARD_CONV_1X1_REGISTRY_KEY = "standard_conv_1x1"
STANDARD_CONV_3X1_REGISTRY_KEY = "standard_conv_3x1"
STANDARD_CONV_5X1_REGISTRY_KEY = "standard_conv_5x1"
STANDARD_ATTENTION_REGISTRY_KEY = "standard_attention"
class TranslationLayer(object):
  """Interface for the layers used in the Translation search space.

  Subclasses implement `_apply_logic()` and `num_params()`. `apply_layer()`
  wraps `_apply_logic()` with pad masking, optional preprocessing, an optional
  activation, optional dropout and an optional residual connection.
  """

  # NOTE(review): a `__metaclass__` class attribute is only honored by
  # Python 2. Under Python 3 it is an ordinary attribute, so the
  # @abc.abstractmethod markers below are presumably not enforced at
  # instantiation time - confirm which Python versions are targeted.
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def _apply_logic(self, input_tensor, output_depth, hparams, var_scope_suffix,
                   nonpadding, mask_future, **kwargs):
    """Applies the layer specific logic to the `input_tensor`.

    This is called by `apply_layer()` to apply the subclass specific logic to
    the preprocessed `input_tensor`.

    Args:
      input_tensor: [batch_size, batch time_steps, embedding_depth] tensor.
      output_depth: Depth of the output tensor.
      hparams: Hyperparameters for the layer.
      var_scope_suffix: Suffix appended to the end of the variable scope.
      nonpadding: a [batch_size, batch time_steps] tensor with 1 where each
        batch member has sequence information and 0 everywhere else. This is
        used to mask out the irrelevant padded portions of the input.
      mask_future: Boolean. If False, information moves across the
        spatial/temporal dimension freely. If True, each timestep can only
        process the information that has come before it.
      **kwargs: Subclass-specific arguments.

    Returns:
      logic_output: [batch_size, batch time_steps, output_depth] tensor output
        of the logic.
    """

  def apply_layer(self,
                  input_tensor,
                  residual_tensor,
                  output_depth,
                  activation,
                  hparams,
                  var_scope_suffix,
                  nonpadding,
                  mask_future,
                  layer_preprocess_fn=None,
                  postprocess_dropout=True,
                  **kwargs):
    """Applies the layer to the input.

    Also applies pad masking, preprocessing, postprocessing, and nonlinearity.
    Order of operations: mask input -> preprocess (then re-mask) ->
    `_apply_logic` -> activation -> dropout -> add residual -> mask output.

    Args:
      input_tensor: [batch_size, batch time_steps, embedding_depth] tensor.
      residual_tensor: Tensor that gets added to the output residually when it
        is not None.
      output_depth: Depth of the output tensor.
      activation: Activation to be applied to the `layer_output`. If None, no
        activation will be applied.
      hparams: Hyperparameters for the layer; `relu_dropout` sets the dropout
        rate.
      var_scope_suffix: Suffix appended to the end of the variable scope.
      nonpadding: a [batch_size, batch time_steps] tensor with 1 where each
        batch member has sequence information and 0 everywhere else. This is
        used to mask out the irrelevant padded portions of the input.
      mask_future: Boolean. If False, information moves across the
        spatial/temporal dimension freely. If True, each timestep can only
        process the information that has come before it.
      layer_preprocess_fn: Preprocess function applied to the input.
      postprocess_dropout: Whether or not to apply dropout.
      **kwargs: Arguments used by specific TranslationLayers.

    Returns:
      layer_output: The output of the layer.
    """
    input_depth = input_tensor.shape.as_list()[-1]
    layer_output = input_tensor
    # Zero out padded positions of the input.
    if nonpadding is not None:
      nonpadding_input_tiled = tf.tile(
          tf.expand_dims(nonpadding, 2), [1, 1, input_depth])
      layer_output *= nonpadding_input_tiled
    if layer_preprocess_fn:
      layer_output = layer_preprocess_fn(layer_output)
      # Re-mask: the preprocess function may write nonzero values into padded
      # positions.
      if nonpadding is not None:
        layer_output *= nonpadding_input_tiled
    layer_output = self._apply_logic(layer_output, output_depth, hparams,
                                     var_scope_suffix, nonpadding, mask_future,
                                     **kwargs)
    if activation:
      layer_output = activation(layer_output)
    if postprocess_dropout:
      # In tf.compat.v1 the second positional argument of tf.nn.dropout is
      # keep_prob, hence `1 - relu_dropout`.
      layer_output = tf.nn.dropout(layer_output, 1 - hparams.relu_dropout)
    if residual_tensor is not None:
      layer_output += residual_tensor
    # Remove the output padding items.
    if nonpadding is not None:
      nonpadding_output_tiled = tf.tile(
          tf.expand_dims(nonpadding, 2), [1, 1, output_depth])
      layer_output *= nonpadding_output_tiled
    return layer_output

  @abc.abstractmethod
  def num_params(self, input_depth, output_depth, **kwargs):
    """Returns num_params in the layer for the given input and output depths.

    NOTE: This does not include layer norm parameters that appear in
      layer_preprocess or layer_postprocess!

    Args:
      input_depth: The depth of the input.
      output_depth: The depth of the output.
      **kwargs: TranslationLayer specific arguments.

    Returns:
      The number of trainable parameters.
    """
class LayerRegisteredError(Exception):
  """Layer name is already used in LayerRegistry.

  Raised when a name is re-registered with an object that compares unequal to
  the one already stored under it.
  """
class LayerRegistry(object):
  """Registry of TranslationLayers.

  The registry is a mapping of string names to TranslationLayers. Layers can be
  added to the registry via `register_layer()` and can be accessed via `get()`.
  """

  def __init__(self):
    self._layers = {}

  def register_layer(self, name, translation_layer):
    """Register a TranslationLayer under the key `name`.

    Registering an object equal to the one already stored under `name` is
    allowed and simply rebinds the name.

    Raises:
      LayerRegisteredError: if `name` is already bound to an unequal object.
    """
    if name in self._layers and self._layers[name] != translation_layer:
      raise LayerRegisteredError(
          "Already registered %s in layer registry with a different object!" %
          name)
    self._layers[name] = translation_layer

  def get(self, name):
    """Returns the layer registered under `name` (KeyError if missing)."""
    return self._layers[name]

  def get_layer_names(self):
    """Returns the registered layer names in sorted order."""
    return sorted(self._layers)
# Global registries of the layers available to decoder and encoder stacks.
DECODER_LAYERS = LayerRegistry()
ENCODER_LAYERS = LayerRegistry()
class ConvLayerBase(TranslationLayer):
  """Convolution TranslationLayer base class.

  Subclasses implement `_conv_function()`, which receives the input with a
  size-1 axis inserted at position 2 and must return a tensor of the same
  rank; that axis is squeezed away afterwards.
  """

  def __init__(self, conv_type, conv_width, dilation_rate):
    self._conv_type = conv_type
    self._conv_width = conv_width
    self._dilation_rate = dilation_rate

  def _conv_function(self, input_tensor, output_depth, padding):
    """Conv function that will be applied to the input tensor."""
    raise NotImplementedError()

  def _apply_logic(self, input_tensor, output_depth, hparams, var_scope_suffix,
                   nonpadding, mask_future, **unused_kwargs):
    """Applies conv logic to `input_tensor`.

    With `mask_future`, the time axis is left-padded by
    (conv_width - 1) * dilation_rate and a VALID convolution is used, so no
    output position sees later timesteps; otherwise SAME padding is used.
    """
    scope_name = "%s_conv_%s" % (self._conv_type, var_scope_suffix)
    with tf.variable_scope(scope_name):
      if not mask_future:
        conv_input, padding = input_tensor, "SAME"
      else:
        left_pad = int(self._conv_width - 1) * self._dilation_rate
        conv_input = tf.pad(
            input_tensor, paddings=[[0, 0], [left_pad, 0], [0, 0]])
        padding = "VALID"
      conv_output = self._conv_function(
          tf.expand_dims(conv_input, 2), output_depth, padding)
      return tf.squeeze(conv_output, 2)
class SeparableConvLayer(ConvLayerBase):
  """Separable convolution TranslationLayer base class."""

  def __init__(self, conv_width):
    super(SeparableConvLayer, self).__init__("separable", conv_width, 1)

  def _conv_function(self, input_tensor, output_depth, padding):
    # SeparableConv1D expects rank-3 input: drop axis 2, convolve, restore it.
    separable_conv_1d = tf.layers.SeparableConv1D(
        output_depth,
        self._conv_width,
        padding=padding,
        name="separable_conv_%sx1" % self._conv_width)
    squeezed = tf.squeeze(input_tensor, 2)
    return tf.expand_dims(separable_conv_1d.apply(squeezed), 2)

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # Depthwise kernel + pointwise kernel + bias.
    depthwise_weights = self._conv_width * input_depth
    pointwise_weights = input_depth * output_depth
    return depthwise_weights + pointwise_weights + output_depth
class StandardConvLayer(ConvLayerBase):
  """Standard convolutional TranslationLayer base class."""

  def __init__(self, conv_width):
    super(StandardConvLayer, self).__init__("standard", conv_width, 1)

  def _conv_function(self, input_tensor, output_depth, padding):
    kernel_size = [self._conv_width, 1]
    layer_name = "conv_%sx1" % self._conv_width
    return tf.layers.conv2d(
        input_tensor, output_depth, kernel_size,
        padding=padding, name=layer_name)

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # Full kernel plus one bias per output channel.
    kernel_weights = self._conv_width * input_depth * output_depth
    return kernel_weights + output_depth
def calculate_depthwise_channel_multiplier(input_depth, output_depth):
  """Calculates channel multiplier for depthwise convolution.

  If `output_depth` is at least `input_depth` and an exact multiple of it,
  the multiplier `output_depth // input_depth` satisfies the requested output
  depth and is returned. Otherwise the multiplier falls back to 1.
  """
  is_exact_multiple = (output_depth >= input_depth and
                       output_depth % input_depth == 0)
  return output_depth // input_depth if is_exact_multiple else 1
class DepthwiseConvLayer(ConvLayerBase):
  """Depthwise convolution TranslationLayer base class (no bias)."""

  def __init__(self, conv_width):
    super(DepthwiseConvLayer, self).__init__("depthwise", conv_width, 1)

  def _conv_function(self, input_tensor, output_depth, padding):
    """Runs a depthwise conv whose channel multiplier yields `output_depth`.

    Raises:
      ValueError: if `output_depth` is smaller than, or not a multiple of,
        the input tensor's last dimension.
    """
    input_depth = input_tensor.shape.as_list()[-1]
    if output_depth < input_depth or output_depth % input_depth != 0:
      raise ValueError(
          "Depthwise layer output_depth (%s) must be greater or equal to and "
          "a multiple of the depth of the input tensor (%s)." %
          (output_depth, input_depth))
    multiplier = calculate_depthwise_channel_multiplier(
        input_depth, output_depth)
    kernel_shape = [self._conv_width, 1, input_depth, multiplier]
    kernel = tf.get_variable("kernel", kernel_shape)
    return tf.nn.depthwise_conv2d(
        input_tensor,
        kernel, [1, 1, 1, 1],
        padding=padding,
        name="depthwise_conv_%sx1" % self._conv_width)

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # Only the kernel variable is created; there is no bias.
    multiplier = calculate_depthwise_channel_multiplier(
        input_depth, output_depth)
    return self._conv_width * input_depth * multiplier
class LightweightConvLayer(ConvLayerBase):
  """Lightweight convolution TranslationLayer base class.

  A depthwise convolution whose kernel covers `input_depth // num_repeat`
  channels and is tiled `num_repeat` times along the channel axis. Any
  `input_depth % num_repeat` leftover channels get their own, untied kernel
  slice. The assembled kernel is softmax-normalized over the width axis
  (axis 0) before use.
  """

  def __init__(self, conv_width, num_repeat):
    # Reuses the "depthwise" conv type, so the variable scope built from
    # `_conv_type` is shared with DepthwiseConvLayer's naming scheme.
    super(LightweightConvLayer, self).__init__("depthwise", conv_width, 1)
    self._num_repeat = num_repeat

  def _conv_function(self, input_tensor, output_depth, padding):
    """Builds the tied, softmax-normalized kernel and applies it depthwise.

    Raises:
      ValueError: if `output_depth` is smaller than, or not a multiple of,
        the input tensor's last dimension.
    """
    input_depth = input_tensor.shape.as_list()[-1]
    if output_depth < input_depth or output_depth % input_depth != 0:
      raise ValueError(
          "Depthwise layer output_depth (%s) must be greater or equal to and "
          "a multiple of the depth of the input tensor (%s)." %
          (output_depth, input_depth))
    multiplier = calculate_depthwise_channel_multiplier(
        input_depth, output_depth)
    shared_depth, leftover_depth = divmod(input_depth, self._num_repeat)
    width = self._conv_width
    shared_kernel = tf.get_variable(
        "kernel_base", [width, 1, shared_depth, multiplier])
    kernel = tf.concat([shared_kernel] * self._num_repeat, axis=2)
    if leftover_depth:
      leftover_kernel = tf.get_variable(
          "nonrepeated_kernel_variables",
          [width, 1, leftover_depth, multiplier])
      kernel = tf.concat([kernel, leftover_kernel], axis=2)
    # Normalize each channel's taps across the convolution width.
    kernel = tf.nn.softmax(kernel, axis=0)
    return tf.nn.depthwise_conv2d(
        input_tensor,
        kernel, [1, 1, 1, 1],
        padding=padding,
        name="lightweight_conv_%sx1_r_%s" % (self._conv_width,
                                            self._num_repeat))

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # Shared slice plus untied leftover slice; no bias.
    multiplier = calculate_depthwise_channel_multiplier(
        input_depth, output_depth)
    shared_depth, leftover_depth = divmod(input_depth, self._num_repeat)
    return self._conv_width * (shared_depth + leftover_depth) * multiplier
class DilatedConvLayer(ConvLayerBase):
  """Dilated (atrous) convolution TranslationLayer base class (no bias)."""

  def __init__(self, conv_width):
    # NOTE(review): the third ConvLayerBase argument is presumably the
    # dilation rate read back as `self._dilation_rate`; confirm in the base.
    super(DilatedConvLayer, self).__init__("dilated", conv_width, 2)

  def _conv_function(self, input_tensor, output_depth, padding):
    """Runs `tf.nn.atrous_conv2d` with rate `self._dilation_rate`."""
    kernel_shape = [self._conv_width, 1, input_tensor.shape.as_list()[-1],
                    output_depth]
    kernel = tf.get_variable("kernel", kernel_shape)
    return tf.nn.atrous_conv2d(
        input_tensor,
        kernel,
        self._dilation_rate,
        padding=padding,
        name="dilated_conv_%sx1" % self._conv_width)

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # Only the kernel variable is created; there is no bias.
    return self._conv_width * input_depth * output_depth
class AttentionLayer(TranslationLayer):
  """Attention layer base class (multi-head self-attention).

  Constructor args:
    hidden_dim_multiplier: Factor applied to the input depth to obtain the
      key/value depth passed to `multihead_attention`. The product is
      truncated with `int()`.
    project_q, project_k, project_v: Which input projections `num_params()`
      counts. NOTE(review): `_apply_logic` does not read these flags; they
      only affect the reported parameter count.
    num_heads: Number of attention heads; when None, `hparams.num_heads` is
      used.
  """

  def __init__(self,
               hidden_dim_multiplier,
               project_q,
               project_k,
               project_v,
               num_heads=None):
    self._hidden_dim_multiplier = hidden_dim_multiplier
    self._project_q = project_q
    self._project_k = project_k
    self._project_v = project_v
    self._num_heads = num_heads

  def _apply_logic(self,
                   input_tensor,
                   output_depth,
                   hparams,
                   var_scope_suffix,
                   nonpadding,
                   mask_future,
                   decoder_self_attention_bias=None,
                   attention_dropout_broadcast_dims=None,
                   **kwargs):
    """Applies attention logic to `input_tensor`.

    The attention bias is taken from the `decoder_self_attention_bias` kwarg
    (None when not supplied).
    """
    with tf.variable_scope("standard_attention_layer_" + var_scope_suffix):
      hidden_depth = int(
          input_tensor.shape.as_list()[-1] * self._hidden_dim_multiplier)
      attention_bias = decoder_self_attention_bias
      # TODO(davidso): This dropout rate differs from the other layers. This
      #                should be fixed so that they all use the same dropout
      #                rate.
      num_heads = self._num_heads
      if num_heads is None:
        num_heads = hparams.num_heads
      logic_output = common_attention.multihead_attention(
          input_tensor,
          None,
          attention_bias,
          hidden_depth,
          hidden_depth,
          output_depth,
          num_heads,
          hparams.attention_dropout,
          attention_type=hparams.self_attention_type,
          max_relative_position=hparams.max_relative_position,
          dropout_broadcast_dims=attention_dropout_broadcast_dims)
    return logic_output

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # Truncate exactly as `_apply_logic` does, so the count agrees with the
    # depth actually passed to `multihead_attention` (and stays an int) for
    # non-integer multipliers.
    hidden_depth = int(input_depth * self._hidden_dim_multiplier)
    # First account for the hidden to output projection params.
    output_params = hidden_depth * output_depth
    # Next account for all the hidden projections.
    num_projections = sum([self._project_q, self._project_k, self._project_v])
    return input_depth * hidden_depth * num_projections + output_params
class AttendToEncoderLayerBase(TranslationLayer):
  """Attend to encoder base, with configurable encoder attend points.

  Subclasses choose which entry of `encoder_cell_outputs` each decoder cell
  attends to by implementing `_determine_encoder_cell_index`.
  """

  def _determine_encoder_cell_index(self, cell_number, num_encoder_cells):
    """Determine the encoder cell index to attend to."""
    raise NotImplementedError()

  def _apply_logic(self,
                   input_tensor,
                   output_depth,
                   hparams,
                   var_scope_suffix,
                   nonpadding,
                   mask_future,
                   encoder_decoder_attention_bias,
                   encoder_cell_outputs,
                   cell_number,
                   attention_dropout_broadcast_dims=None,
                   **unused_kwargs):
    """Attends from `input_tensor` to one selected encoder cell output."""
    with tf.variable_scope("attend_to_encoder_layer_" + var_scope_suffix):
      query_depth = int(input_tensor.shape.as_list()[-1])
      target_index = self._determine_encoder_cell_index(
          cell_number, len(encoder_cell_outputs))
      attended_encoder_output = encoder_cell_outputs[target_index]
      # TODO(davidso): This dropout rate differs from the other layers. This
      #                should be fixed so that they all use the same dropout
      #                rate.
      return common_attention.multihead_attention(
          input_tensor,
          attended_encoder_output,
          encoder_decoder_attention_bias,
          query_depth,
          query_depth,
          output_depth,
          hparams.num_heads,
          hparams.attention_dropout,
          attention_type=hparams.self_attention_type,
          max_relative_position=hparams.max_relative_position,
          dropout_broadcast_dims=attention_dropout_broadcast_dims)

  # Assumes uniform encoder output depths.
  def num_params(self, input_depth, output_depth, **kwargs):
    """Counts attention projection weights.

    Raises:
      ValueError: if `encoder_depth` is not supplied in kwargs.
    """
    if "encoder_depth" not in kwargs:
      raise ValueError("`encoder_depth` must be in kwargs passed to "
                       "AttendToEncoder.num_params().")
    encoder_depth = kwargs["encoder_depth"]
    hidden_depth = input_depth
    # Input -> hidden projection, two encoder -> hidden projections, and the
    # hidden -> output projection, all sharing the hidden depth factor.
    return hidden_depth * (input_depth + 2 * encoder_depth + output_depth)
class AttendToEncoderTopDownLayer(AttendToEncoderLayerBase):
  """Attend to the encoder starting with the highest layer, then moving down.

  This allows the decoder to see higher level features first and then
  eventually move on to incorporate lower level information.
  """

  def __init__(self, delay, increment_step):
    self.delay = delay
    self.increment_step = increment_step

  def _determine_encoder_cell_index(self, cell_number, num_encoder_cells):
    """Attend to final encoder cell output first, then move down.

    After `delay` cells, each further decoder cell moves `increment_step`
    encoder cells down. The index never goes below 0.
    """
    cells_moved_down = max(0, (cell_number - self.delay) * self.increment_step)
    last_cell_index = num_encoder_cells - 1
    return max(0, last_cell_index - cells_moved_down)
class GatedLinearUnitLayer(TranslationLayer):
  """Gated Linear Unit layer: a linear projection times a sigmoid gate."""

  def __init__(self):
    pass

  def _apply_logic(self, input_tensor, output_depth, hparams, var_scope_suffix,
                   nonpadding, mask_future, **unused_kwargs):
    """Returns dense(x) * sigmoid(dense(x)) using two separate dense layers."""
    # The linear projection is built first, then the gate, as before.
    linear_part = tf.layers.dense(input_tensor, output_depth)
    gate = tf.layers.dense(
        input_tensor, output_depth, activation=tf.nn.sigmoid)
    return linear_part * gate

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # Two dense layers, each with an input_depth x output_depth kernel and an
    # output_depth bias.
    return 2 * output_depth * (input_depth + 1)
class IdentityLayer(TranslationLayer):
  """Identity TranslationLayer: passes its input through unchanged."""

  def _apply_logic(self, input_tensor, output_depth, hparams, var_scope_suffix,
                   nonpadding, mask_future, **unused_kwargs):
    """Returns `input_tensor`.

    Raises:
      ValueError: if `output_depth` differs from the input's last dimension.
    """
    input_depth = input_tensor.shape.as_list()[-1]
    if output_depth == input_depth:
      return input_tensor
    raise ValueError(
        "Identity layer output_depth (%s) must be equal to the depth of the "
        "input tensor (%s)." % (output_depth, input_depth))

  def num_params(self, input_depth, output_depth, **unused_kwargs):
    # No variables are created.
    return 0
def register_encoder_decoder_layer(name, translation_layer):
  """Registers the same `translation_layer` instance in both registries.

  The encoder registry is updated first, then the decoder registry.
  """
  for layer_registry in (ENCODER_LAYERS, DECODER_LAYERS):
    layer_registry.register_layer(name, translation_layer)
# Register all strictly decoder layers.
# With delay=0 and increment_step=0 the top-down index formula always yields
# num_encoder_cells - 1, i.e. every decoder cell attends to the last encoder
# cell output.
DECODER_LAYERS.register_layer(
    ATTEND_TO_ENCODER_REGISTRY_KEY,
    AttendToEncoderTopDownLayer(delay=0, increment_step=0))
# Register all encoder and decoder layers.
# Each call below creates one layer instance that is shared by both the
# encoder and decoder registries.
register_encoder_decoder_layer(IDENTITY_REGISTRY_KEY, IdentityLayer())
# Separable convolutions of increasing width.
register_encoder_decoder_layer(SEPARABLE_CONV_3X1_REGISTRY_KEY,
                               SeparableConvLayer(conv_width=3))
register_encoder_decoder_layer(SEPARABLE_CONV_5X1_REGISTRY_KEY,
                               SeparableConvLayer(conv_width=5))
register_encoder_decoder_layer(SEPARABLE_CONV_7X1_REGISTRY_KEY,
                               SeparableConvLayer(conv_width=7))
register_encoder_decoder_layer(SEPARABLE_CONV_9X1_REGISTRY_KEY,
                               SeparableConvLayer(conv_width=9))
register_encoder_decoder_layer(SEPARABLE_CONV_11X1_REGISTRY_KEY,
                               SeparableConvLayer(conv_width=11))
register_encoder_decoder_layer(SEPARABLE_CONV_13X1_REGISTRY_KEY,
                               SeparableConvLayer(conv_width=13))
register_encoder_decoder_layer(SEPARABLE_CONV_15X1_REGISTRY_KEY,
                               SeparableConvLayer(conv_width=15))
# Standard (dense) convolutions.
register_encoder_decoder_layer(STANDARD_CONV_1X1_REGISTRY_KEY,
                               StandardConvLayer(conv_width=1))
register_encoder_decoder_layer(STANDARD_CONV_3X1_REGISTRY_KEY,
                               StandardConvLayer(conv_width=3))
register_encoder_decoder_layer(STANDARD_CONV_5X1_REGISTRY_KEY,
                               StandardConvLayer(conv_width=5))
# Depthwise convolutions.
register_encoder_decoder_layer(DEPTHWISE_CONV_3X1_REGISTRY_KEY,
                               DepthwiseConvLayer(conv_width=3))
register_encoder_decoder_layer(DEPTHWISE_CONV_5X1_REGISTRY_KEY,
                               DepthwiseConvLayer(conv_width=5))
register_encoder_decoder_layer(DEPTHWISE_CONV_7X1_REGISTRY_KEY,
                               DepthwiseConvLayer(conv_width=7))
# Dilated convolutions.
register_encoder_decoder_layer(DILATED_CONV_3X1_REGISTRY_KEY,
                               DilatedConvLayer(conv_width=3))
register_encoder_decoder_layer(DILATED_CONV_5X1_REGISTRY_KEY,
                               DilatedConvLayer(conv_width=5))
# Lightweight convolutions: every combination of width (3, 5, 7, 15) and
# kernel repeat count (1, 4, 16).
register_encoder_decoder_layer(LIGHTWEIGHT_CONV_3X1_R_1_REGISTRY_KEY,
                               LightweightConvLayer(conv_width=3, num_repeat=1))
register_encoder_decoder_layer(LIGHTWEIGHT_CONV_3X1_R_4_REGISTRY_KEY,
                               LightweightConvLayer(conv_width=3, num_repeat=4))
register_encoder_decoder_layer(
    LIGHTWEIGHT_CONV_3X1_R_16_REGISTRY_KEY,
    LightweightConvLayer(conv_width=3, num_repeat=16))
register_encoder_decoder_layer(LIGHTWEIGHT_CONV_5X1_R_1_REGISTRY_KEY,
                               LightweightConvLayer(conv_width=5, num_repeat=1))
register_encoder_decoder_layer(LIGHTWEIGHT_CONV_5X1_R_4_REGISTRY_KEY,
                               LightweightConvLayer(conv_width=5, num_repeat=4))
register_encoder_decoder_layer(
    LIGHTWEIGHT_CONV_5X1_R_16_REGISTRY_KEY,
    LightweightConvLayer(conv_width=5, num_repeat=16))
register_encoder_decoder_layer(LIGHTWEIGHT_CONV_7X1_R_1_REGISTRY_KEY,
                               LightweightConvLayer(conv_width=7, num_repeat=1))
register_encoder_decoder_layer(LIGHTWEIGHT_CONV_7X1_R_4_REGISTRY_KEY,
                               LightweightConvLayer(conv_width=7, num_repeat=4))
register_encoder_decoder_layer(
    LIGHTWEIGHT_CONV_7X1_R_16_REGISTRY_KEY,
    LightweightConvLayer(conv_width=7, num_repeat=16))
register_encoder_decoder_layer(
    LIGHTWEIGHT_CONV_15X1_R_1_REGISTRY_KEY,
    LightweightConvLayer(conv_width=15, num_repeat=1))
register_encoder_decoder_layer(
    LIGHTWEIGHT_CONV_15X1_R_4_REGISTRY_KEY,
    LightweightConvLayer(conv_width=15, num_repeat=4))
register_encoder_decoder_layer(
    LIGHTWEIGHT_CONV_15X1_R_16_REGISTRY_KEY,
    LightweightConvLayer(conv_width=15, num_repeat=16))
register_encoder_decoder_layer(
    GATED_LINEAR_UNIT_REGISTRY_KEY,
    GatedLinearUnitLayer())
# Self-attention variants. STANDARD_ATTENTION leaves num_heads=None, so it
# falls back to hparams.num_heads; the others pin the head count.
register_encoder_decoder_layer(
    STANDARD_ATTENTION_REGISTRY_KEY,
    AttentionLayer(
        hidden_dim_multiplier=1, project_q=True, project_k=True,
        project_v=True))
register_encoder_decoder_layer(
    ATTENTION_16_HEADS_REGISTRY_KEY,
    AttentionLayer(
        hidden_dim_multiplier=1,
        project_q=True,
        project_k=True,
        project_v=True,
        num_heads=16))
register_encoder_decoder_layer(
    ATTENTION_32_HEADS_REGISTRY_KEY,
    AttentionLayer(
        hidden_dim_multiplier=1,
        project_q=True,
        project_k=True,
        project_v=True,
        num_heads=32))
register_encoder_decoder_layer(
    ATTENTION_4_HEADS_REGISTRY_KEY,
    AttentionLayer(
        hidden_dim_multiplier=1,
        project_q=True,
        project_k=True,
        project_v=True,
        num_heads=4))
================================================
FILE: tensor2tensor/models/neural_architecture_search/nas_layers_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import itertools
from absl.testing import parameterized
import numpy as np
from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer
from tensor2tensor.models.neural_architecture_search import nas_layers as layers
import tensorflow.compat.v1 as tf
# Shape of the synthetic inputs fed to every layer under test:
# [_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, _INPUT_DEPTH].
_BATCH_SIZE = 32
_TOTAL_SEQUENCE_LENGTH = 20
_INPUT_DEPTH = 256
# Number of (identical) encoder cell outputs handed to decoder layers, and the
# decoder cell index used when applying them.
_NUM_CELLS = 6
_CELL_NUMBER = 3
# Layers whose names contain one of these prefixes are not tested for
# resizing outputs; they are always tested with output_depth == _INPUT_DEPTH.
_RESIZE_EXEMPT_LAYER_PREFIXES = [
    "depthwise_conv",
    "squeeze_and_excitation",
    "identity",
    "lightweight_conv",
]
def _apply_encoder_layer(translation_layer, output_depth, nonpadding_list):
  """Applies an encoder layer with basic arguments.

  The input is uniform noise scaled by 1/4 and the residual is uniform noise
  of depth `output_depth`. The layer is applied with a ReLU activation, no
  future masking, the given `nonpadding_list` mask, and postprocess dropout
  enabled.
  """
  input_shape = [_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, _INPUT_DEPTH]
  input_tensor = tf.random_uniform(input_shape) / 4.0
  nonpadding = tf.constant(nonpadding_list)
  residual_shape = [_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, output_depth]
  residual_tensor = tf.random_uniform(residual_shape)
  return translation_layer.apply_layer(
      input_tensor,
      residual_tensor,
      output_depth,
      tf.nn.relu,
      transformer.transformer_base(),
      "",
      mask_future=False,
      nonpadding=nonpadding,
      layer_preprocess_fn=None,
      postprocess_dropout=True)
def _apply_decoder_layer(translation_layer, input_tensor, output_depth,
                         encoder_depth):
  """Applies a decoder layer with basic arguments.

  The residual and encoder output are random constants shifted to be centered
  on zero. All _NUM_CELLS encoder cells share the same output tensor.
  Attention dropout is set to 0, and the layer is applied with future masking
  and a lower-triangular self-attention bias.
  """

  def _centered_random_constant(depth):
    values = np.random.rand(_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, depth) - .5
    return tf.constant(values, dtype=tf.float32)

  residual_tensor = _centered_random_constant(output_depth)
  encoder_output = _centered_random_constant(encoder_depth)
  hparams = transformer.transformer_base()
  hparams.attention_dropout = 0
  return translation_layer.apply_layer(
      input_tensor,
      residual_tensor,
      output_depth,
      None,
      hparams,
      "",
      nonpadding=None,
      mask_future=True,
      layer_preprocess_fn=None,
      postprocess_dropout=False,
      decoder_self_attention_bias=(
          common_attention.attention_bias_lower_triangle(
              _TOTAL_SEQUENCE_LENGTH)),
      encoder_decoder_attention_bias=None,
      encoder_cell_outputs=[encoder_output] * _NUM_CELLS,
      cell_number=_CELL_NUMBER)
def _zero_after_index_copy(feed_input, zero_after_index):
  """Creates a copy of `feed_input` with zeros after `zero_after_index`.

  For each of the first _BATCH_SIZE rows of a deep copy, every time step from
  `zero_after_index + 1` up to _TOTAL_SEQUENCE_LENGTH is overwritten with
  zeros. The original input is left untouched.
  """
  zeroed = copy.deepcopy(feed_input)
  for batch_index in range(_BATCH_SIZE):
    sequence = zeroed[batch_index]
    for step in range(zero_after_index + 1, _TOTAL_SEQUENCE_LENGTH):
      sequence[step] = [0.0] * len(sequence[step])
  return zeroed
def _get_empirical_parameters():
  """Gets the number of parameters built into the current Tensorflow graph.

  Sums the element counts of every variable in `tf.trainable_variables()`.
  """
  return sum(np.prod(variable.shape) for variable in tf.trainable_variables())
def _create_nonpadding_list():
  """Creates the `nonpadding_list` for applying the encoder layers.

  Row i starts with min(i + 2, _TOTAL_SEQUENCE_LENGTH) ones (real tokens),
  followed by zeros (padding) up to _TOTAL_SEQUENCE_LENGTH.
  """

  def _mask_row(num_real):
    return [1.0] * num_real + [0.0] * (_TOTAL_SEQUENCE_LENGTH - num_real)

  return [_mask_row(min(row + 2, _TOTAL_SEQUENCE_LENGTH))
          for row in range(_BATCH_SIZE)]
class LayersTest(parameterized.TestCase, tf.test.TestCase):
  """Tests params, residual capabilities, padding leaks, and output shape."""

  @staticmethod
  def _effective_output_depth(translation_layer_name, output_depth):
    """Returns the output depth `translation_layer_name` is tested with.

    Layers whose names contain an entry of _RESIZE_EXEMPT_LAYER_PREFIXES are
    not tested for resizing, so they always use _INPUT_DEPTH.
    """
    if any(prefix in translation_layer_name
           for prefix in _RESIZE_EXEMPT_LAYER_PREFIXES):
      return _INPUT_DEPTH
    return output_depth

  # Test that the encoder registry contains all the expected layers.
  def test_encoder_registry(self):
    encoder_layers = [
        "separable_conv_3x1",
        "separable_conv_5x1",
        "separable_conv_7x1",
        "separable_conv_9x1",
        "separable_conv_11x1",
        "separable_conv_13x1",
        "separable_conv_15x1",
        "standard_conv_1x1",
        "standard_conv_3x1",
        "standard_conv_5x1",
        "depthwise_conv_3x1",
        "depthwise_conv_5x1",
        "depthwise_conv_7x1",
        "dilated_conv_3x1",
        "dilated_conv_5x1",
        "standard_attention",
        "identity",
        "attention_4_heads",
        "attention_16_heads",
        "attention_32_heads",
        "gated_linear_unit",
        "lightweight_conv_3x1_r_1",
        "lightweight_conv_3x1_r_4",
        "lightweight_conv_3x1_r_16",
        "lightweight_conv_5x1_r_1",
        "lightweight_conv_5x1_r_4",
        "lightweight_conv_5x1_r_16",
        "lightweight_conv_7x1_r_1",
        "lightweight_conv_7x1_r_4",
        "lightweight_conv_7x1_r_16",
        "lightweight_conv_15x1_r_1",
        "lightweight_conv_15x1_r_4",
        "lightweight_conv_15x1_r_16",
    ]
    self.assertSameElements(encoder_layers,
                            layers.ENCODER_LAYERS.get_layer_names())

  # Test that the decoder registry contains all the expected layers.
  def test_decoder_registry(self):
    decoder_layers = sorted([
        "separable_conv_3x1",
        "separable_conv_5x1",
        "separable_conv_7x1",
        "separable_conv_9x1",
        "separable_conv_11x1",
        "separable_conv_13x1",
        "separable_conv_15x1",
        "standard_conv_1x1",
        "standard_conv_3x1",
        "standard_conv_5x1",
        "depthwise_conv_3x1",
        "depthwise_conv_5x1",
        "depthwise_conv_7x1",
        "dilated_conv_3x1",
        "dilated_conv_5x1",
        "standard_attention",
        "attend_to_encoder",
        "identity",
        "attention_4_heads",
        "attention_16_heads",
        "attention_32_heads",
        "gated_linear_unit",
        "lightweight_conv_3x1_r_1",
        "lightweight_conv_3x1_r_4",
        "lightweight_conv_3x1_r_16",
        "lightweight_conv_5x1_r_1",
        "lightweight_conv_5x1_r_4",
        "lightweight_conv_5x1_r_16",
        "lightweight_conv_7x1_r_1",
        "lightweight_conv_7x1_r_4",
        "lightweight_conv_7x1_r_16",
        "lightweight_conv_15x1_r_1",
        "lightweight_conv_15x1_r_4",
        "lightweight_conv_15x1_r_16",
    ])
    self.assertSameElements(decoder_layers,
                            layers.DECODER_LAYERS.get_layer_names())

  # Test encoder layer. This includes checking that output dims are as
  # expected, checking that num_params() agrees with the empirical number of
  # variables produced, and that information does not leak from 0 padded
  # areas of the input.
  @parameterized.parameters(
      itertools.product(layers.ENCODER_LAYERS.get_layer_names(),
                        (256, 128, 512)))
  def test_encoder_layer(self, translation_layer_name, output_depth):
    with self.test_session(graph=tf.Graph()) as sess:
      nonpadding_list = _create_nonpadding_list()
      output_depth = self._effective_output_depth(translation_layer_name,
                                                  output_depth)
      translation_layer = layers.ENCODER_LAYERS.get(translation_layer_name)
      output_tensor = _apply_encoder_layer(translation_layer, output_depth,
                                           nonpadding_list)

      # Check that the output shape is as expected.
      self.assertEqual(output_tensor.shape.as_list(),
                       [_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, output_depth])

      # Check that the number of parameters is as expected.
      empirical_num_params = _get_empirical_parameters()
      reported_num_params = translation_layer.num_params(
          _INPUT_DEPTH, output_depth)
      self.assertEqual(empirical_num_params, reported_num_params)

      # Make sure padding is applied properly (no leaks).
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      for i, j in itertools.product(
          range(_BATCH_SIZE), range(_TOTAL_SEQUENCE_LENGTH)):
        if nonpadding_list[i][j] == 0:
          self.assertAllEqual(output[i][j], np.array([0] * output_depth),
                              "Output row %s, column %s not zeroed out." %
                              (i, j))

  # Test decoder layer. This includes checking that output dims are as
  # expected, checking that num_params() agrees with the empirical number of
  # variables produced, and that temporal information does not leak.
  @parameterized.parameters(
      itertools.product(layers.DECODER_LAYERS.get_layer_names(),
                        (256, 128, 512)))
  def test_decoder_layer(self, translation_layer_name, output_depth):
    with self.test_session(graph=tf.Graph()) as sess:
      # Check that the output shape is as expected.
      input_tensor = tf.placeholder(
          tf.float32, [_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, _INPUT_DEPTH])
      encoder_depth = int(_INPUT_DEPTH / 2)
      output_depth = self._effective_output_depth(translation_layer_name,
                                                  output_depth)
      translation_layer = layers.DECODER_LAYERS.get(translation_layer_name)
      output_tensor = _apply_decoder_layer(translation_layer, input_tensor,
                                           output_depth, encoder_depth)
      self.assertEqual(output_tensor.shape.as_list(),
                       [_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, output_depth])

      # Check that the number of parameters is as expected.
      empirical_num_params = _get_empirical_parameters()
      reported_num_params = translation_layer.num_params(
          _INPUT_DEPTH,
          output_depth,
          encoder_depth=encoder_depth)
      self.assertEqual(empirical_num_params, reported_num_params)

      # Check that there is no temporal information leak. Specifically, check
      # that values before `test_index` remain unchanged, while the values
      # after it have changed. Sums are used because two values could
      # potentially be the same between the zero and non-zero portions, even
      # if the masking is working correctly. Note: This assumes that the
      # output at t is dependent on the input at t.
      feed_input = np.random.random(
          [_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, _INPUT_DEPTH]) / 10.0
      test_index = int(_TOTAL_SEQUENCE_LENGTH / 2)
      transformed_feed_input = _zero_after_index_copy(feed_input, test_index)

      # Produce the outputs for both types of input. Variables are fed random
      # values directly instead of being initialized. `tf.all_variables()` is
      # a deprecated alias of `tf.global_variables()`.
      feed_dict = {
          v: np.random.rand(*v.shape.as_list()) - .5
          for v in tf.global_variables()
      }
      feed_dict[input_tensor] = feed_input
      control_output = sess.run(output_tensor, feed_dict)
      feed_dict[input_tensor] = transformed_feed_input
      variable_output = sess.run(output_tensor, feed_dict)

      self.assertAllClose(
          control_output[:, :test_index + 1],
          variable_output[:, :test_index + 1],
          rtol=1)

      with self.assertRaises(
          AssertionError,
          msg="Time-masked portion of output too close to control output."):
        self.assertAllClose(
            control_output[:, test_index + 1:],
            variable_output[:, test_index + 1:],
            rtol=1)
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
================================================
FILE: tensor2tensor/models/neural_architecture_search/nas_model.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NasSeq2Seq class which can be configured to produce a variety of models.
This was the class used in the Evolved Transformer paper
(https://arxiv.org/abs/1901.11117) to create configurable models. It can be used
to train models in the search space as was done in the paper.
To use NasSeq2Seq:
- set model=nas_seq2_seq.
- set hparams_set=nas_seq2seq_base.
- use hparams to specify the configuration you want to run. See
nas_seq2seq_base() for an example.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.models.neural_architecture_search import nas_layers as layers
from tensor2tensor.utils import contrib
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
# Keys for the activation map.
LEAKY_RELU_ACTIVATION_KEY = "leaky_relu"
NONE_ACTIVATION_KEY = "none"
RELU_ACTIVATION_KEY = "relu"
SIGMOID_ACTIVATION_KEY = "sigmoid"
SWISH_ACTIVATION_KEY = "swish"
SOFTMAX_ACTIVATION_KEY = "softmax"
# Mapping from string names to activation function. The "none" key maps to
# None rather than a callable. Insertion order here is the order in which
# get_activation_names() yields the keys (Python 3.7+ dict ordering).
ACTIVATION_MAP = {
    SWISH_ACTIVATION_KEY: tf.nn.swish,
    LEAKY_RELU_ACTIVATION_KEY: tf.nn.leaky_relu,
    RELU_ACTIVATION_KEY: tf.nn.relu,
    NONE_ACTIVATION_KEY: None,
    SIGMOID_ACTIVATION_KEY: tf.nn.sigmoid,
    SOFTMAX_ACTIVATION_KEY: tf.nn.softmax
}
# Norm strings. Note NO_NORM_KEY has the same value ("none") as
# NONE_ACTIVATION_KEY.
LAYER_NORM_KEY = "layer_norm"
NO_NORM_KEY = "none"
# Combiner function strings (keys of COMBINER_FUNCTIONS).
ADD_COMBINER_FUNC_KEY = "add"
MULTIPLY_COMBINER_FUNC_KEY = "multiply"
CONCAT_COMBINER_FUNC_KEY = "concat"
# Layers that force the output_dim to be equal to the input_dim if
# enforce_fixed_output_sizes is True.
LAYERS_TO_FIX_OUTPUT_SIZE = [
    layers.IDENTITY_REGISTRY_KEY,
]
# Depthwise layers that the output dimension will need to be changed for
# if channel multiplier cannot be changed to match output dimension.
DEPTHWISE_LAYERS = [
    layers.DEPTHWISE_CONV_3X1_REGISTRY_KEY,
    layers.DEPTHWISE_CONV_5X1_REGISTRY_KEY,
    layers.DEPTHWISE_CONV_7X1_REGISTRY_KEY
]
# NOTE(review): presumably marks a branch that is pruned from the cell; its
# consumers are outside this view — confirm.
DEAD_BRANCH_KEY = "dead_branch"
def should_alter_output_dim(layer_name, enforce_fixed_output_sizes, input_depth,
                            output_depth):
  """Check if the output_depth for the specified layer should be changed.

  The output depth is flagged for change only when
  `enforce_fixed_output_sizes` is truthy and either:
    * `layer_name` is in LAYERS_TO_FIX_OUTPUT_SIZE, or
    * `layer_name` is in DEPTHWISE_LAYERS and the computed depthwise channel
      multiplier is 1. This includes output_depth == input_depth, where
      changing the depth is a no-op.

  Returns:
    False when no change is needed, otherwise `enforce_fixed_output_sizes`.
  """
  if layer_name in DEPTHWISE_LAYERS:
    multiplier = layers.calculate_depthwise_channel_multiplier(
        input_depth, output_depth)
    needs_new_dim = (multiplier == 1 or
                     layer_name in LAYERS_TO_FIX_OUTPUT_SIZE)
  else:
    needs_new_dim = layer_name in LAYERS_TO_FIX_OUTPUT_SIZE
  return needs_new_dim and enforce_fixed_output_sizes
def get_activation_names():
  """Returns the valid activation keys of ACTIVATION_MAP.

  On Python 3 this is a live `dict_keys` view, not a list.
  """
  return ACTIVATION_MAP.keys()
def _pad_shallow_tensors(tensors, pad_value):
  """Pads the shorter tensors to be as long as the longest.

  Only the last axis is padded, on the right, with `pad_value`. The padding
  spec has three entries, so tensors are expected to be rank 3.

  Args:
    tensors: List of tensors with statically known last dimensions.
    pad_value: Fill value for the added entries.

  Returns:
    A new list, in the same order, in which every tensor's last dimension
    equals the largest last dimension in `tensors`. Tensors that are already
    that deep are passed through as-is.
  """
  depths = [tensor.shape.as_list()[-1] for tensor in tensors]
  max_dim = max([0] + depths)
  output_tensors = []
  for tensor, dim in zip(tensors, depths):
    if dim < max_dim:
      output_tensors.append(
          tf.pad(
              tensor, [[0, 0], [0, 0], [0, max_dim - dim]],
              constant_values=pad_value))
    else:
      output_tensors.append(tensor)
  # The previous version printed `output_tensors` here (leftover debug
  # output on every combine); removed.
  return output_tensors
# `__metaclass__ = abc.ABCMeta` only takes effect on Python 2; on Python 3 it
# is silently ignored and the abstract methods are never enforced.
# six.add_metaclass applies ABCMeta on both versions.
@six.add_metaclass(abc.ABCMeta)
class CombinerFunction(object):
  """Interface for combiner functions.

  Implementations merge a list of branch tensors into one tensor and report
  the last-axis depth of the merged result.
  """

  @abc.abstractmethod
  def combine_tensors(self, tensors):
    """Combines `tensors`.

    Args:
      tensors: List of tensors to combine.

    Returns:
      Combined tensor.
    """

  @abc.abstractmethod
  def combined_output_dim(self, output_dims):
    """Determines the output dimension of the combined tensor.

    Args:
      output_dims: List of output dimensions of combined tensors.

    Returns:
      Output dimension of the combined tensor.
    """
class AddCombiner(CombinerFunction):
  """Addition CombinerFunction; zero-pads shallower tensors before summing."""

  def combine_tensors(self, tensors):
    assert tensors
    if len(tensors) == 1:
      return tensors[0]
    padded = _pad_shallow_tensors(tensors, 0)
    first, second, remaining = padded[0], padded[1], padded[2:]
    total = first + second
    for extra in remaining:
      total += extra
    return total

  def combined_output_dim(self, output_dims):
    # Zero-padding makes the sum as deep as the deepest input.
    return max(output_dims)
class MultiplyCombiner(CombinerFunction):
  """Multiply CombinerFunction; pads shallower tensors with ones first."""

  def combine_tensors(self, tensors):
    assert tensors
    if len(tensors) == 1:
      return tensors[0]
    padded = _pad_shallow_tensors(tensors, 1)
    first, second, remaining = padded[0], padded[1], padded[2:]
    product = first * second
    for extra in remaining:
      product *= extra
    return product

  def combined_output_dim(self, output_dims):
    # Padding with ones makes the product as deep as the deepest input.
    return max(output_dims)
class ConcatCombiner(CombinerFunction):
  """Concat CombinerFunction; joins tensors along axis 2."""

  def combine_tensors(self, tensors):
    assert tensors
    return tensors[0] if len(tensors) == 1 else tf.concat(tensors, 2)

  def combined_output_dim(self, output_dims):
    # Concatenation depth is the sum of the input depths.
    return sum(output_dims)
# Dict of combiner functions where each key is the function key string and each
# value is a CombinerFunction subclass (the class itself, not an instance).
# Instantiate it to get an object whose combine_tensors() merges a list of
# tensors and whose combined_output_dim() reports the merged depth.
COMBINER_FUNCTIONS = {
    ADD_COMBINER_FUNC_KEY: AddCombiner,
    MULTIPLY_COMBINER_FUNC_KEY: MultiplyCombiner,
    CONCAT_COMBINER_FUNC_KEY: ConcatCombiner,
}
@registry.register_model
class NasSeq2Seq(transformer.Transformer):
"""Configurable seq2seq model used for Neural Architecture Search.
Models are defined by 26 hparam fields. They are:
- <encoder/decoder>_num_cells: The number of cells in the <encoder/decoder>.
- <encoder/decoder>_<left/right>_layers: List of layers used in the
    <encoder/decoder> <left/right> branch. For available layers, see
    the nas_layers.py file.
- <encoder/decoder>_<left/right>_inputs: List of inputs to the
    <encoder/decoder> <left/right> layers. Each index i specifies the
    i_th layer's output with 0 representing the cell input tensor.
- <encoder/decoder>_<left/right>_output_dims: List of absolute output
    dimensions for each layer.
- <encoder/decoder>_<left/right>_activation: List of activations applied
    after each layer. ACTIVATION_MAP holds the valid activations.
- <encoder/decoder>_<left/right>_norms: List of norms applied before each
    layer. Must be either "layer_norm" or "none".
- <encoder/decoder>_combiner_functions: List of functions used to combine
    each left/right branch pair. Options are listed in
    COMBINER_FUNCTIONS.
- <encoder/decoder>_final_combiner_function: Function applied to combine
    all the block outputs that are not used as inputs to other blocks.
    Options are listed in COMBINER_FUNCTIONS.
For an example of how to set these hparams, please see nas_seq2seq_base().
"""
__metaclass__ = abc.ABCMeta
def encode(self, inputs, target_space, hparams, features=None, losses=None):
"""Encode inputs using _encoder().
This performs the same way as transformer.Transformer.encode with the
encoder portion replaced with _encoder().
Args:
inputs: Input [batch_size, input_length, input_height, hidden_dim] tensor
which will be flattened along the two spatial dimensions.
target_space: scalar, target space ID.
hparams: Hyperparmeters for model.
features: Optionally pass the entire features dictionary as well. This is
needed now for "packed" datasets.
losses: Unused list of losses.
Returns:
Tuple of:
encoder_output: Encoder representation.
[batch_size, input_length, hidden_dim]
encoder_decoder_attention_bias: Bias and mask weights for
encodre-decoder attention. [batch_size, input_length]
Raises:
ValueError: If encoder type not found.
"""
inputs = common_layers.flatten4d3d(inputs)
encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
transformer.transformer_prepare_encoder(
inputs, target_space, hparams, features=features))
encoder_input = tf.nn.dropout(encoder_input,
1.0 - hparams.layer_prepostprocess_dropout)
encoder_output = self._encoder(
encoder_input,
self_attention_bias,
hparams,
nonpadding=transformer.features_to_nonpadding(features, "inputs"),
save_weights_to=self.attention_weights)
return encoder_output, encoder_decoder_attention_bias
def decode(self,
decoder_input,
encoder_output,
encoder_decoder_attention_bias,
decoder_self_attention_bias,
hparams,
cache=None,
nonpadding=None,
losses=None):
"""Decode inputs using _decoder().
This performs the same way as transformer.Transformer.decode with the
decoder portion replaced with _decoder().
Args:
decoder_input: Inputs to bottom of the model. [batch_size, decoder_length,
hidden_dim]
encoder_output: Encoder representation. [batch_size, input_length,
hidden_dim]
encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
attention. [batch_size, input_length]
decoder_self_attention_bias: Bias and mask weights for decoder
self-attention. [batch_size, decoder_length]
hparams: Hyperparmeters for model.
cache: Dict, containing tensors which are the results of previous
attentions, used for fast decoding.
nonpadding: Optional Tensor with shape [batch_size, decoder_length]
losses: Unused losses.
Returns:
Final decoder representation. [batch_size, decoder_length, hidden_dim]
"""
decoder_input = tf.nn.dropout(decoder_input,
1.0 - hparams.layer_prepostprocess_dropout)
decoder_output = self._decoder(
decoder_input,
encoder_output,
decoder_self_attention_bias,
encoder_decoder_attention_bias,
hparams,
cache=cache,
nonpadding=nonpadding,
save_weights_to=self.attention_weights)
if (common_layers.is_xla_compiled() and
hparams.mode == tf_estimator.ModeKeys.TRAIN):
# TPU does not react kindly to extra dimensions.
return decoder_output
# Expand since t2t expects 4d tensors.
return tf.expand_dims(decoder_output, axis=2)
def _encoder(self,
encoder_input,
encoder_self_attention_bias,
hparams,
nonpadding=None,
save_weights_to=None):
encoder_output, encoder_cell_outputs = nas_encoder(
encoder_input, encoder_self_attention_bias, hparams, nonpadding)
self._encoder_cell_outputs = encoder_cell_outputs
return encoder_output
def _decoder(self,
decoder_input,
encoder_output,
decoder_self_attention_bias,
encoder_decoder_attention_bias,
hparams,
cache=None,
nonpadding=None,
save_weights_to=None):
assert self._encoder_cell_outputs
return nas_decoder(decoder_input, self._encoder_cell_outputs,
decoder_self_attention_bias,
encoder_decoder_attention_bias, hparams)
def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
"""Construct EstimatorSpec for EVAL mode."""
if self.hparams.use_tpu:
return self._tpu_estimator_spec_eval(features, logits, labels, loss,
losses_dict)
return self._gpu_estimator_spec_eval(features, logits, labels, loss,
losses_dict)
# This function is overridden because py_func is not supported on distributed
# training, which is necessary for NAS. This function works
# the exact same way as the original Transformer.estimator_spec_eval(),
# except only neg log perplexity is accepted as a metric.
def _gpu_estimator_spec_eval(self, features, logits, labels, loss,
losses_dict):
"""Construct EstimatorSpec for GPU EVAL mode."""
hparams = self.hparams
if not hasattr(hparams, "problem"):
raise NotImplementedError(
"hparams is missing attribute `problem`. NasSeq2Seq must "
"be used with a problem.")
# TPU is not supported.
eval_metrics_fns = metrics.create_evaluation_metrics([hparams.problem],
hparams)
eval_metrics = {}
for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
if "rouge" not in metric_name and "bleu" not in metric_name:
eval_metrics[metric_name] = metric_fn(logits, features,
features["targets"])
return tf_estimator.EstimatorSpec(
tf_estimator.ModeKeys.EVAL,
predictions={"predictions": logits},
eval_metric_ops=eval_metrics,
loss=loss)
def _tpu_estimator_spec_eval(self, features, logits, labels, loss,
losses_dict):
"""Construct EstimatorSpec for TPU EVAL mode."""
del losses_dict
hparams = self.hparams
if not hasattr(hparams, "problem"):
raise NotImplementedError(
"hparams is missing attribute `problem`. NasSeq2Seq must "
"be used with a problem.")
problem = hparams.problem
t2t_model.remove_summaries()
eval_metrics_fn = t2t_model.create_tpu_eval_metrics_fn(problem, hparams)
if isinstance(logits, dict):
# For TPU, logits dict will be passed as keyword arguments to
# eval_metrics_fn. Here we add the labels to those arguments.
logits.update({"labels": labels})
return contrib.tpu().TPUEstimatorSpec(
tf_estimator.ModeKeys.EVAL,
eval_metrics=(eval_metrics_fn, logits),
loss=loss)
else:
return contrib.tpu().TPUEstimatorSpec(
tf_estimator.ModeKeys.EVAL,
eval_metrics=(eval_metrics_fn, [logits, labels]),
loss=loss)
def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha,
use_tpu):
"""Forced slow beam decode.
Args:
features: an map of string to `Tensor`.
decode_length: an integer. How many additional timesteps to decode.
beam_size: number of beams.
top_beams: an integer. How many of the beams to return.
alpha: Float that controls the length penalty. larger the alpha, stronger
the preference for longer translations.
use_tpu: Whether or not TPU is being used.
Returns:
A dict of decoding results {
"outputs": integer `Tensor` of decoded ids of shape
[batch_size, <= decode_length] if beam_size == 1 or
[batch_size, top_beams, <= decode_length].
"scores": decoding log probs from the beam search,
None if using greedy decoding (beam_size=1).
}
"""
return self._beam_decode_slow(features, decode_length, beam_size, top_beams,
alpha, use_tpu)
def _apply_layer_norm(input_tensor, nonpadding, hparams):
  """Applies Tensor2Tensor layer_norm to |input_tensor|.

  When |nonpadding| is given, padded positions are zeroed both before the
  normalization (so they contribute zeros to it) and after it (so the
  normalization's bias does not leak into padded positions).

  Args:
    input_tensor: Tensor to normalize. The mask is expanded at axis 2, so this
      is presumably [batch_size, length, depth].
    nonpadding: Optional rank-2 tensor with 1s at non-padding positions and 0s
      elsewhere, or None to skip masking.
    hparams: Hyperparameters forwarded to common_layers.layer_preprocess.

  Returns:
    The normalized (and, if |nonpadding| is given, masked) tensor.
  """
  if nonpadding is None:
    return common_layers.layer_preprocess(input_tensor, hparams)

  # A trailing singleton axis broadcasts the mask over the depth axis. Unlike
  # tf.tile, this does not need the static depth to be known.
  mask = tf.expand_dims(nonpadding, 2)
  # Normalize the *masked* input so padded positions feed in zeros.
  output_tensor = common_layers.layer_preprocess(input_tensor * mask, hparams)
  return output_tensor * mask
def _apply_nas_branch(norm, layer_norm_dict, hidden_states, nonpadding, hparams,
                      input_index, layer_name, activation_name, layer_registry,
                      output_dim, branch_scope_name, mask_future,
                      dropout_broadcast_dims, encoder_decoder_attention_bias,
                      encoder_cell_outputs, decoder_self_attention_bias,
                      cell_number):
  """Builds a single left or right NAS branch.

  The hidden state at |input_index| is optionally layer-normalized, then fed
  through the registered layer |layer_name| with activation |activation_name|.

  Args:
    norm: LAYER_NORM_KEY or NO_NORM_KEY.
    layer_norm_dict: Cache from hidden-state index to its normalized tensor.
      It is filled in here so each hidden state is normalized at most once per
      dict.
    hidden_states: List of hidden-state tensors available to this branch.
    nonpadding: Optional non-padding mask forwarded to the norm and the layer.
    hparams: Model hyperparameters.
    input_index: Index into |hidden_states| read by this branch.
    layer_name: Registry key of the layer to apply.
    activation_name: ACTIVATION_MAP key of the activation to apply.
    layer_registry: Registry from which |layer_name| is looked up.
    output_dim: Output dimension of the layer; converted with int().
    branch_scope_name: Variable scope name for this branch.
    mask_future: Whether the layer must mask future positions.
    dropout_broadcast_dims: Attention dropout broadcast dims.
    encoder_decoder_attention_bias: Forwarded to the layer.
    encoder_cell_outputs: Forwarded to the layer.
    decoder_self_attention_bias: Forwarded to the layer.
    cell_number: Index of the cell this branch belongs to.

  Returns:
    The branch output tensor.

  Raises:
    ValueError: If |norm| is neither LAYER_NORM_KEY nor NO_NORM_KEY.
  """
  with tf.variable_scope(branch_scope_name):
    if norm == LAYER_NORM_KEY:
      # Reuse a cached normalization of the same hidden state when available.
      if input_index not in layer_norm_dict:
        layer_norm_dict[input_index] = _apply_layer_norm(
            hidden_states[input_index], nonpadding, hparams)
      branch_input = layer_norm_dict[input_index]
    elif norm == NO_NORM_KEY:
      branch_input = hidden_states[input_index]
    else:
      raise ValueError("norm must be either '%s' or '%s'. Got %s" %
                       (LAYER_NORM_KEY, NO_NORM_KEY, norm))

    chosen_layer = layer_registry.get(layer_name)
    chosen_activation = ACTIVATION_MAP[activation_name]
    # Post-processing dropout is disabled for identity layers.
    use_postprocess_dropout = layer_name != layers.IDENTITY_REGISTRY_KEY

    branch_output = chosen_layer.apply_layer(
        branch_input,
        None,
        int(output_dim),
        chosen_activation,
        hparams,
        branch_scope_name,
        mask_future=mask_future,
        layer_preprocess_fn=None,
        postprocess_dropout=use_postprocess_dropout,
        nonpadding=nonpadding,
        attention_dropout_broadcast_dims=dropout_broadcast_dims,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        encoder_cell_outputs=encoder_cell_outputs,
        cell_number=cell_number,
        decoder_self_attention_bias=decoder_self_attention_bias)
  return branch_output
def apply_nas_layers(input_tensor,
left_inputs,
left_layers,
left_activations,
left_output_dims,
left_norms,
right_inputs,
right_layers,
right_activations,
right_output_dims,
right_norms,
combiner_functions,
final_combiner_function,
num_cells,
nonpadding,
layer_registry,
mask_future,
hparams,
var_scope,
encoder_decoder_attention_bias=None,
encoder_cell_outputs=None,
decoder_self_attention_bias=None,
final_layer_norm=True,
enforce_fixed_output_sizes=True):
"""Applies layers with NasNet search space style branching.
The given layers are repeated `num_cells` times. Each cell starts from hidden
state h_0: the input tensor for the first cell, and the previous cell's output
for later cells. Layer i reads hidden states through its left and right
branches, combines the two branch outputs with combiner_functions[i] (or takes
the single live branch when the other is dead), and appends the result as
h_{i+1}. Hidden states that no live (non-dead) branch reads are combined with
`final_combiner_function` to produce the cell output.
Args:
input_tensor: Input [batch_size, input_length, hidden_dim] sequence tensor.
left_inputs: Int list of left branch hidden layer input indexes.
left_layers: String list of left branch layers.
left_activations: String list of left branch activations.
left_output_dims: List of left branch output dimensions.
left_norms: String list of left branch norms.
right_inputs: Int list of right branch hidden layer input indexes.
right_layers: String list of right branch layers.
right_activations: String list of right branch activations.
right_output_dims: List of right branch output dimensions.
right_norms: String list of right branch norms.
combiner_functions: String list of branch combining functions
(COMBINER_FUNCTIONS keys).
final_combiner_function: String. The COMBINER_FUNCTIONS key of the final
combiner that combines all the unused hidden layers in a cell.
num_cells: The number of cells. This is the number of times the given
layers will be repeated.
nonpadding: Tensor with 1s at all nonpadding time step positions and 0s
everywhere else.
layer_registry: The LayerRegistry that holds all valid layers.
mask_future: Whether or not to mask future sequence values.
hparams: Hyperparameters for the model.
var_scope: The variable scope name.
encoder_decoder_attention_bias: The attention bias for decoder attending to
`encoder_output`.
encoder_cell_outputs: List of tensors. The encoder cell outputs, listed in
order.
decoder_self_attention_bias: The self attention bias for decoders. This
needs to be set for decoders.
final_layer_norm: Whether or not to apply a final layer_norm to the output
of the model.
enforce_fixed_output_sizes: Whether or not to automatically resize output
dimensions to match the input dimension if `should_alter_output_dim()`
returns True.
Raises:
ValueError: When branching inputs are not of the same length.
ValueError: If item in left_norms is not LAYER_NORM_KEY or NO_NORM_KEY.
ValueError: If item in right_norms is not LAYER_NORM_KEY or NO_NORM_KEY.
Returns:
A tuple (output, cell_outputs): the last cell's output and the list of every
cell's output in order. When `final_layer_norm` is True, layer_preprocess is
applied to the output and, separately, to each entry of cell_outputs.
"""
if not (len(left_inputs) == len(left_layers) == len(left_activations) ==
len(left_output_dims) == len(left_norms) == len(right_inputs) ==
len(right_layers) == len(right_activations) == len(right_output_dims)
== len(right_norms) == len(combiner_functions)):
raise ValueError("All branching inputs must be of the same length.")
cell_output = None
# Hidden-state indexes read by live (non-dead) branches.
# NOTE(review): these entries are not passed through int(), unlike
# left_input/right_input inside the cell loop below. If a config ever holds
# non-int indexes, the membership tests for unused_cell_hidden_states will not
# match. Confirm the inputs are always ints.
modified_left_inputs = [
left_inputs[i]
for i in range(len(left_inputs))
if left_layers[i] != DEAD_BRANCH_KEY
]
modified_right_inputs = [
right_inputs[i]
for i in range(len(right_inputs))
if right_layers[i] != DEAD_BRANCH_KEY
]
# Hidden states h_0..h_n that no live branch consumes; these are combined into
# the cell output with final_combiner_function.
unused_cell_hidden_states = [
i for i in range(len(left_inputs) + 1)
if i not in modified_left_inputs and i not in modified_right_inputs
]
# At least one hidden state must remain to form the cell output.
assert unused_cell_hidden_states
cell_outputs = []
with tf.variable_scope(var_scope):
dropout_broadcast_dims = (
common_layers.comma_separated_string_to_integer_list(
getattr(hparams, "attention_dropout_broadcast_dims", "")))
for cell_num in range(num_cells):
# h_0 is the input tensor for the first cell and the previous cell's output
# for every later cell.
# layer_norm_dict caches normalized hidden states for the current cell.
if cell_output is not None:
cell_hidden_states = [cell_output]
else:
cell_hidden_states = [input_tensor]
layer_norm_dict = {}
with tf.variable_scope("cell_%d" % cell_num):
for i, (left_input, left_layer_name, left_activation_name,
left_output_dim, left_norm, right_input, right_layer_name,
right_activation_name, right_output_dim, right_norm,
combiner) in enumerate(
zip(left_inputs, left_layers, left_activations,
left_output_dims, left_norms, right_inputs,
right_layers, right_activations, right_output_dims,
right_norms, combiner_functions)):
left_input = int(left_input)
right_input = int(right_input)
with tf.variable_scope("layer_%d" % i):
# A layer may not have both of its branches dead.
assert not (left_layer_name == DEAD_BRANCH_KEY and
right_layer_name == DEAD_BRANCH_KEY)
if left_layer_name != DEAD_BRANCH_KEY:
left_raw_input_tensor = cell_hidden_states[left_input]
left_input_dim = left_raw_input_tensor.shape.as_list()[-1]
if should_alter_output_dim(left_layer_name,
enforce_fixed_output_sizes,
left_input_dim, left_output_dim):
left_output_dim = left_input_dim
# First process the left branch.
left_tensor = _apply_nas_branch(
norm=left_norm,
layer_norm_dict=layer_norm_dict,
hidden_states=cell_hidden_states,
nonpadding=nonpadding,
hparams=hparams,
input_index=left_input,
layer_name=left_layer_name,
activation_name=left_activation_name,
layer_registry=layer_registry,
output_dim=left_output_dim,
branch_scope_name="left_%s" % str(i),
mask_future=mask_future,
dropout_broadcast_dims=dropout_broadcast_dims,
encoder_decoder_attention_bias=encoder_decoder_attention_bias,
encoder_cell_outputs=encoder_cell_outputs,
decoder_self_attention_bias=decoder_self_attention_bias,
cell_number=cell_num)
if right_layer_name != DEAD_BRANCH_KEY:
right_raw_input_tensor = cell_hidden_states[right_input]
right_input_dim = right_raw_input_tensor.shape.as_list()[-1]
if should_alter_output_dim(right_layer_name,
enforce_fixed_output_sizes,
right_input_dim, right_output_dim):
right_output_dim = right_input_dim
# Next process the right branch.
right_tensor = _apply_nas_branch(
norm=right_norm,
layer_norm_dict=layer_norm_dict,
hidden_states=cell_hidden_states,
nonpadding=nonpadding,
hparams=hparams,
input_index=right_input,
layer_name=right_layer_name,
activation_name=right_activation_name,
layer_registry=layer_registry,
output_dim=right_output_dim,
branch_scope_name="right_%s" % str(i),
mask_future=mask_future,
dropout_broadcast_dims=dropout_broadcast_dims,
encoder_decoder_attention_bias=encoder_decoder_attention_bias,
encoder_cell_outputs=encoder_cell_outputs,
decoder_self_attention_bias=decoder_self_attention_bias,
cell_number=cell_num)
# Combine the branches. A dead branch contributes nothing, so the live
# branch's output is used directly.
if left_layer_name == DEAD_BRANCH_KEY:
hidden_tensor = right_tensor
elif right_layer_name == DEAD_BRANCH_KEY:
hidden_tensor = left_tensor
else:
hidden_tensor = COMBINER_FUNCTIONS[combiner]().combine_tensors(
[left_tensor, right_tensor])
cell_hidden_states.append(hidden_tensor)
# The cell output combines every hidden state no live branch consumed.
states_to_combine = [
cell_hidden_states[j] for j in unused_cell_hidden_states
]
cell_output = COMBINER_FUNCTIONS[final_combiner_function](
).combine_tensors(states_to_combine)
cell_outputs.append(cell_output)
# layer_preprocess is applied separately to the final output and to each cell
# output.
if final_layer_norm:
final_output = common_layers.layer_preprocess(cell_output, hparams)
cell_outputs = [
common_layers.layer_preprocess(cell_output, hparams)
for cell_output in cell_outputs
]
return final_output, cell_outputs
else:
return cell_output, cell_outputs
def nas_encoder(encoder_input,
                encoder_self_attention_bias,
                hparams,
                nonpadding=None,
                final_layer_norm=True):
  """Encoder for configurable NAS model.

  Args:
    encoder_input: Input tensor.
    encoder_self_attention_bias: Attention bias tensor with 0s for all valid
      positions and large negative numbers for the padding positions.
    hparams: transformer.Transformer hparams that must also contain:
      + encoder_<left|right>_inputs: List of ints specifying the hidden layer
          input indexes for the branches.
      + encoder_<left|right>_layers: String list of layers. Each string must be
          the name of a TranslationLayer registered in layers.py's
          ENCODER_LAYERS.
      + encoder_<left|right>_activations: String list of activations. Each
          string in this list must have a corresponding activation in
          ACTIVATION_MAP.
      + encoder_<left|right>_output_dims: Int list of output dimensions for
          branch layers.
      + encoder_<left|right>_norms: String list of norms to apply to the
          layer branches. Each item must be either LAYER_NORM_KEY or
          NO_NORM_KEY.
      + encoder_num_cells: The number of cells in the encoder. This determines
          how many times the given layers will be repeated.
      + encoder_combiner_functions: String list of functions used to combine
          left and right branches. Each must be a COMBINER_FUNCTIONS key.
      + encoder_final_combiner_function: COMBINER_FUNCTIONS key used to combine
          each cell's unused hidden states into the cell output.
    nonpadding: Tensor with 1s at all nonpadding positions and 0s everywhere
      else. If None (default), then nonpadding will be determined from
      encoder_self_attention_bias.
    final_layer_norm: Whether or not to apply a final layer_norm to the output
      of the encoder.

  Returns:
    Encoder output and list of each encoder cell's output in order.
  """
  if nonpadding is None:
    # Recover the mask from the bias: padded positions carry large negative
    # values there.
    padding = common_attention.attention_bias_to_padding(
        encoder_self_attention_bias)
    nonpadding = 1.0 - padding

  return apply_nas_layers(
      input_tensor=encoder_input,
      left_inputs=hparams.encoder_left_inputs,
      left_layers=hparams.encoder_left_layers,
      left_activations=hparams.encoder_left_activations,
      left_output_dims=hparams.encoder_left_output_dims,
      left_norms=hparams.encoder_left_norms,
      right_inputs=hparams.encoder_right_inputs,
      right_layers=hparams.encoder_right_layers,
      right_activations=hparams.encoder_right_activations,
      right_output_dims=hparams.encoder_right_output_dims,
      right_norms=hparams.encoder_right_norms,
      num_cells=hparams.encoder_num_cells,
      combiner_functions=hparams.encoder_combiner_functions,
      final_combiner_function=hparams.encoder_final_combiner_function,
      nonpadding=nonpadding,
      layer_registry=layers.ENCODER_LAYERS,
      mask_future=False,
      hparams=hparams,
      var_scope="encoder",
      final_layer_norm=final_layer_norm)
def nas_decoder(decoder_input,
                encoder_cell_outputs,
                decoder_self_attention_bias,
                encoder_decoder_attention_bias,
                hparams,
                final_layer_norm=True):
  """Decoder for configurable model.

  Args:
    decoder_input: Input tensor.
    encoder_cell_outputs: List of tensors. The encoder cell outputs, listed in
      order.
    decoder_self_attention_bias: Attention bias that the decoder uses when
      attending to itself. This should have 0s for all valid positions and large
      negative numbers for all hidden future positions.
    encoder_decoder_attention_bias: Attention bias that the decoder uses when
      attending to the encoder. This should be 0s at all valid positions and
      large negative numbers for all padded positions.
    hparams: transformer.Transformer hparams that must also contain:
      + decoder_<left|right>_inputs: List of ints specifying the hidden layer
          input indexes for the branches.
      + decoder_<left|right>_layers: String list of layers. Each string must be
          the name of a TranslationLayer registered in layers.py's
          DECODER_LAYERS.
      + decoder_<left|right>_activations: String list of activations. Each
          string in this list must have a corresponding activation in
          ACTIVATION_MAP.
      + decoder_<left|right>_output_dims: Int list of output dimensions for
          branch layers.
      + decoder_<left|right>_norms: String list of norms to apply to the
          layer branches. Each item must be either LAYER_NORM_KEY or
          NO_NORM_KEY.
      + decoder_num_cells: The number of cells in the decoder. This determines
          how many times the given layers will be repeated.
      + decoder_combiner_functions: String list of functions used to combine
          left and right branches. Each must be a COMBINER_FUNCTIONS key.
      + decoder_final_combiner_function: COMBINER_FUNCTIONS key used to combine
          each cell's unused hidden states into the cell output.
      hparams may also optionally contain:
      + enforce_output_size: Boolean that determines whether or not the decoder
          output must be resized to hparams.hidden_size. If True, the output
          will be resized if it is not equal to hparams.hidden_size. If False,
          the output will not be resized. If this field is not set, behavior
          defaults to True.
    final_layer_norm: Whether or not to apply a final layer norm to the output
      of the decoder.

  Returns:
    Decoder output tensor.
  """
  # Compute the depth the decoder cells will produce, so we know whether a
  # projection back to hparams.hidden_size is required.
  (_, output_depth, _, _) = calculate_branching_model_parameters(
      encoding_depth=hparams.hidden_size,
      left_inputs=hparams.decoder_left_inputs,
      left_layers=hparams.decoder_left_layers,
      left_output_dims=hparams.decoder_left_output_dims,
      right_inputs=hparams.decoder_right_inputs,
      right_layers=hparams.decoder_right_layers,
      right_output_dims=hparams.decoder_right_output_dims,
      combiner_functions=hparams.decoder_combiner_functions,
      final_combiner_function=hparams.decoder_final_combiner_function,
      layer_registry=layers.DECODER_LAYERS,
      num_cells=hparams.decoder_num_cells,
      encoder_depth=hparams.hidden_size)
  improper_output_size = output_depth != hparams.hidden_size

  # enforce_output_size is optional and defaults to True.
  enforce_output_size = getattr(hparams, "enforce_output_size", True)
  resize_output = enforce_output_size and improper_output_size

  decoder_cells_output, _ = apply_nas_layers(
      input_tensor=decoder_input,
      left_inputs=hparams.decoder_left_inputs,
      left_layers=hparams.decoder_left_layers,
      left_activations=hparams.decoder_left_activations,
      left_output_dims=hparams.decoder_left_output_dims,
      left_norms=hparams.decoder_left_norms,
      right_inputs=hparams.decoder_right_inputs,
      right_layers=hparams.decoder_right_layers,
      right_activations=hparams.decoder_right_activations,
      right_output_dims=hparams.decoder_right_output_dims,
      right_norms=hparams.decoder_right_norms,
      num_cells=hparams.decoder_num_cells,
      combiner_functions=hparams.decoder_combiner_functions,
      final_combiner_function=hparams.decoder_final_combiner_function,
      nonpadding=None,
      layer_registry=layers.DECODER_LAYERS,
      mask_future=True,
      hparams=hparams,
      var_scope="decoder",
      decoder_self_attention_bias=decoder_self_attention_bias,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      encoder_cell_outputs=encoder_cell_outputs,
      final_layer_norm=final_layer_norm)

  if not resize_output:
    return decoder_cells_output

  # Resize output if necessary with a 1x1 convolution to hparams.hidden_size.
  dense_layer = layers.DECODER_LAYERS.get(layers.STANDARD_CONV_1X1_REGISTRY_KEY)
  output = dense_layer.apply_layer(
      decoder_cells_output,
      None,
      hparams.hidden_size,
      None,
      hparams,
      "decoder_resize_dense",
      mask_future=True,
      layer_preprocess_fn=None,
      postprocess_dropout=True,
      nonpadding=None,
      attention_dropout_broadcast_dims=None,
      encoder_decoder_attention_bias=None,
      encoder_cell_outputs=None,
      decoder_self_attention_bias=None,
  )
  if final_layer_norm:
    output = common_layers.layer_preprocess(output, hparams)

  return output
def calculate_branching_model_parameters(encoding_depth,
                                         left_inputs,
                                         left_layers,
                                         left_output_dims,
                                         right_inputs,
                                         right_layers,
                                         right_output_dims,
                                         combiner_functions,
                                         layer_registry,
                                         num_cells,
                                         final_combiner_function,
                                         encoder_depth=None,
                                         enforce_output_size=False,
                                         enforce_fixed_output_sizes=True):
  """Calculates the number of parameters in the given model portion.

  Args:
    encoding_depth: Integer. The depth of the initial input tensor.
    left_inputs: Integer list. The indexes of the hidden layer inputs for the
      left branch. Entries are converted with int(), as apply_nas_layers does.
    left_layers: String list. The names of the left branch layers.
    left_output_dims: Integer list. The output dimensions for each of the left
      branch layers.
    right_inputs: Integer list. The indexes of the hidden layer inputs for the
      right branch. Entries are converted with int(), as apply_nas_layers does.
    right_layers: String list. The names of the right branch layers.
    right_output_dims: Integer list. The output dimensions of each of the right
      branch layers.
    combiner_functions: String list. The functions used to combine the left and
      right branch tensors.
    layer_registry: layers.LayerRegistry. The LayerRegistry that contains the
      layers.TranslationLayers needed to construct the model.
    num_cells: Integer. The number of times the given layers are repeated to
      produce the model.
    final_combiner_function: String. The COMBINER_FUNCTIONS key for the combiner
      used to combine the unused hidden dimensions.
    encoder_depth: Integer. The depth of the final encoder layer.
    enforce_output_size: Boolean. If True, include parameters for the addition
      of a dense layer that projects the final output to the appropriate
      `encoding_depth` if it is not already that size. If False, do not add any
      additional parameters.
    enforce_fixed_output_sizes: Whether or not to automatically resize output
      dimensions to match the input dimension if `should_alter_output_dim()`
      returns True.

  Raises:
    ValueError: When the layer config lists are not of equal length.

  Returns:
    total_parameters: The total number of parameters in the model, accounting
      for repeated cells.
    output_depth: The depth of the cell output tensor.
    hidden_depths: The depths of the hidden layers of the last cell.
    unused_outputs: Set of integer indexes of the hidden layers that are not
      used as input, and therefore are combined to produce the cell output.
  """
  if not (len(left_inputs) == len(left_layers) == len(left_output_dims) ==
          len(right_inputs) == len(right_layers) == len(right_output_dims) ==
          len(combiner_functions)):
    raise ValueError("Layer configs must be of equal length.")

  total_parameters = 0
  output_depth = encoding_depth

  for _ in range(num_cells):
    # Each cell starts from the previous cell's output depth.
    hidden_depths = [output_depth]
    unused_outputs = set(range(len(left_inputs) + 1))

    for (left_input, left_layer, left_output_dim, right_input,
         right_layer, right_output_dim, combiner_function) in zip(
             left_inputs, left_layers, left_output_dims, right_inputs,
             right_layers, right_output_dims, combiner_functions):
      # apply_nas_layers casts branch input indexes with int(); do the same so
      # both functions accept identical configs.
      left_input = int(left_input)
      right_input = int(right_input)

      assert not (left_layer == DEAD_BRANCH_KEY and
                  right_layer == DEAD_BRANCH_KEY)

      if left_layer == DEAD_BRANCH_KEY:
        left_parameters = 0
      else:
        left_input_dim = hidden_depths[left_input]
        if should_alter_output_dim(left_layer, enforce_fixed_output_sizes,
                                   left_input_dim, left_output_dim):
          left_output_dim = left_input_dim
        left_parameters = layer_registry.get(left_layer).num_params(
            left_input_dim, left_output_dim, encoder_depth=encoder_depth)

      if right_layer == DEAD_BRANCH_KEY:
        right_parameters = 0
      else:
        right_input_dim = hidden_depths[right_input]
        if should_alter_output_dim(right_layer, enforce_fixed_output_sizes,
                                   right_input_dim, right_output_dim):
          right_output_dim = right_input_dim
        right_parameters = layer_registry.get(right_layer).num_params(
            right_input_dim, right_output_dim, encoder_depth=encoder_depth)

      total_parameters += left_parameters + right_parameters

      if left_layer == DEAD_BRANCH_KEY:
        hidden_dim = right_output_dim
      elif right_layer == DEAD_BRANCH_KEY:
        hidden_dim = left_output_dim
      else:
        hidden_dim = COMBINER_FUNCTIONS[combiner_function](
        ).combined_output_dim([left_output_dim, right_output_dim])
      hidden_depths.append(hidden_dim)

      # Hidden states read by a live branch no longer contribute to the cell
      # output.
      if left_layer != DEAD_BRANCH_KEY:
        unused_outputs.discard(left_input)
      if right_layer != DEAD_BRANCH_KEY:
        unused_outputs.discard(right_input)

    # All unused outputs combined together.
    unused_hidden_depths = [hidden_depths[index] for index in unused_outputs]
    output_depth = COMBINER_FUNCTIONS[final_combiner_function](
    ).combined_output_dim(unused_hidden_depths)

  # Add the resizing layer if needed.
  if output_depth != encoding_depth and enforce_output_size:
    total_parameters += layer_registry.get(
        layers.STANDARD_CONV_1X1_REGISTRY_KEY).num_params(
            output_depth, encoding_depth, encoder_depth=encoder_depth)
  return (total_parameters, output_depth, hidden_depths, unused_outputs)
@registry.register_hparams
def nas_seq2seq_base():
  """Base parameters for Nas Seq2Seq model.

  The default branching configuration is intended to reproduce the
  Transformer:
    - Each encoder cell has a self-attention layer and a 2048-wide ReLU
      feed-forward block.
    - Each decoder cell additionally attends to the encoder.
    - Identity right branches with "add" combiners supply the residual
      connections.

  Returns:
    Hyperparameters for Nas Seq2Seq model.
  """
  hparams = transformer.transformer_base()

  # (name, value) pairs, registered in this exact order.
  nas_settings = [
      # Encoder.
      ("encoder_num_cells", 6),
      ("encoder_left_inputs", [0, 1, 2, 3]),
      ("encoder_left_layers", ["standard_attention", "standard_conv_1x1",
                               "standard_conv_1x1", "identity"]),
      ("encoder_left_output_dims", [512, 2048, 512, 512]),
      ("encoder_left_activations", ["none", "relu", "none", "none"]),
      ("encoder_left_norms", ["layer_norm", "layer_norm", "none", "none"]),
      ("encoder_right_inputs", [0, 1, 1, 1]),
      ("encoder_right_layers",
       ["identity", "dead_branch", "identity", "dead_branch"]),
      ("encoder_right_activations", ["none"] * 4),
      ("encoder_right_output_dims", [512] * 4),
      ("encoder_right_norms", ["none"] * 4),
      ("encoder_combiner_functions", ["add"] * 4),
      ("encoder_final_combiner_function", "add"),
      # Decoder.
      ("decoder_num_cells", 6),
      ("decoder_left_inputs", [0, 1, 2, 3, 4]),
      ("decoder_left_layers", ["standard_attention", "attend_to_encoder",
                               "standard_conv_1x1", "standard_conv_1x1",
                               "identity"]),
      ("decoder_left_activations", ["none", "none", "relu", "none", "none"]),
      ("decoder_left_output_dims", [512, 512, 2048, 512, 512]),
      ("decoder_left_norms",
       ["layer_norm", "layer_norm", "layer_norm", "none", "none"]),
      ("decoder_right_inputs", [0, 1, 2, 2, 4]),
      ("decoder_right_layers",
       ["identity", "identity", "dead_branch", "identity", "dead_branch"]),
      ("decoder_right_activations", ["none"] * 5),
      ("decoder_right_output_dims", [512] * 5),
      ("decoder_right_norms", ["none"] * 5),
      ("decoder_combiner_functions", ["add"] * 5),
      ("decoder_final_combiner_function", "add"),
  ]
  for setting_name, setting_value in nas_settings:
    hparams.add_hparam(setting_name, setting_value)
  return hparams
================================================
FILE: tensor2tensor/models/neural_architecture_search/nas_model_test.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for NasSeq2Seq."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer
from tensor2tensor.models.neural_architecture_search import nas_layers as layers
from tensor2tensor.models.neural_architecture_search import nas_model as translation_nas_net
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
_BATCH_SIZE = 5
_INPUT_LENGTH = 5
_TARGET_LENGTH = 6
_VOCAB_SIZE = 8
_HIDDEN_SIZE = 512
_EMBEDDING_DEPTH = _HIDDEN_SIZE
def _list_product(num_list):
"""Computes product of all elements in a list."""
product = 1
for num in num_list:
product *= num
return product
def _get_transformer_branching_encoder_config():
  """Returns config for the Transformer encoder.

  Returns:
    Tuple of (num_cells, left_inputs, left_layers, left_output_dims,
    right_inputs, right_layers, right_output_dims, combiner_functions,
    final_combiner_function, activations, norms, layer_registry, is_decoder).
  """
  attention = layers.STANDARD_ATTENTION_REGISTRY_KEY
  conv_1x1 = layers.STANDARD_CONV_1X1_REGISTRY_KEY
  identity = layers.IDENTITY_REGISTRY_KEY
  dead = translation_nas_net.DEAD_BRANCH_KEY
  num_layers = 4

  return (
      2,  # num_cells
      [0, 1, 2, 3],  # left_inputs
      [attention, conv_1x1, conv_1x1, identity],  # left_layers
      [512, 2048, 512, 512],  # left_output_dims
      [0, 1, 1, 3],  # right_inputs
      [identity, dead, identity, dead],  # right_layers
      [512] * num_layers,  # right_output_dims
      [translation_nas_net.ADD_COMBINER_FUNC_KEY] * num_layers,
      translation_nas_net.CONCAT_COMBINER_FUNC_KEY,  # final combiner
      [translation_nas_net.NONE_ACTIVATION_KEY] * num_layers,
      [translation_nas_net.NO_NORM_KEY] * num_layers,
      layers.ENCODER_LAYERS,
      False,  # is_decoder
  )
def _get_transformer_branching_decoder_config():
  """Returns config for the Transformer decoder as a branching NAS cell.

  Each of the two cells has five layers. The left branch is standard
  attention, attention to the encoder, two 1x1 convolutions and an identity.
  The right branch mixes identities and 1x1 convolutions. Both branches share
  the same output dims, and cells are merged by concatenation.

  Returns:
    Tuple (num_cells, left_inputs, left_layers, left_output_dims,
    right_inputs, right_layers, right_output_dims, combiner_functions,
    final_combiner_function, activations, norms, layer_registry, is_decoder).
    `activations` and `norms` are placeholder lists (no activation / no norm).
  """
  layer_count = 5
  add = translation_nas_net.ADD_COMBINER_FUNC_KEY
  concat = translation_nas_net.CONCAT_COMBINER_FUNC_KEY
  conv_1x1 = layers.STANDARD_CONV_1X1_REGISTRY_KEY
  identity = layers.IDENTITY_REGISTRY_KEY
  left_output_dims = [512, 512, 1024, 256, 512]
  return (
      2,  # num_cells
      [0, 1, 2, 3, 4],  # left_inputs
      [
          layers.STANDARD_ATTENTION_REGISTRY_KEY,
          layers.ATTEND_TO_ENCODER_REGISTRY_KEY, conv_1x1, conv_1x1, identity
      ],
      left_output_dims,
      [0, 1, 2, 3, 2],  # right_inputs
      [identity, identity, conv_1x1, conv_1x1, identity],  # right_layers
      # Same dims as the left branch, but an independent list: callers edit
      # the left dims in place.
      list(left_output_dims),
      [add, add, concat, concat, add],  # combiner_functions
      concat,  # final combiner
      [translation_nas_net.NONE_ACTIVATION_KEY] * layer_count,
      [translation_nas_net.NO_NORM_KEY] * layer_count,
      layers.DECODER_LAYERS,
      True,  # is_decoder
  )
def _add_transformer_branching_hparams(hparams):
(encoder_num_cells, encoder_left_inputs, encoder_left_layers,
encoder_left_output_dims, encoder_right_inputs, encoder_right_layers,
encoder_right_output_dims, encoder_combiner_functions,
encoder_final_combiner_function, encoder_dummy_activations,
encoder_dummy_norms, _, _) = _get_transformer_branching_encoder_config()
# Transformer encoder.
hparams.add_hparam("encoder_left_inputs", encoder_left_inputs)
hparams.add_hparam("encoder_left_layers", encoder_left_layers)
hparams.add_hparam("encoder_left_activations", encoder_dummy_activations)
hparams.add_hparam("encoder_left_output_dims", encoder_left_output_dims)
hparams.add_hparam("encoder_left_norms", encoder_dummy_norms)
hparams.add_hparam("encoder_right_inputs", encoder_right_inputs)
hparams.add_hparam("encoder_right_layers", encoder_right_layers)
hparams.add_hparam("encoder_right_activations", encoder_dummy_activations)
hparams.add_hparam("encoder_right_output_dims", encoder_right_output_dims)
hparams.add_hparam("encoder_right_norms", encoder_dummy_norms)
hparams.add_hparam("encoder_combiner_functions", encoder_combiner_functions)
hparams.add_hparam("encoder_num_cells", encoder_num_cells)
hparams.add_hparam("encoder_final_combiner_function",
encoder_final_combiner_function)
(decoder_num_cells, decoder_left_inputs, decoder_left_layers,
decoder_left_output_dims, decoder_right_inputs, decoder_right_layers,
decoder_right_output_dims, decoder_combiner_functions,
decoder_final_combiner_function, decoder_dummy_activations,
decoder_dummy_norms, _, _) = _get_transformer_branching_decoder_config()
# Transformer decoder.
hparams.add_hparam("decoder_left_inputs", decoder_left_inputs)
hparams.add_hparam("decoder_left_layers", decoder_left_layers)
hparams.add_hparam("decoder_left_activations", decoder_dummy_activations)
hparams.add_hparam("decoder_left_output_dims", decoder_left_output_dims)
hparams.add_hparam("decoder_left_norms", decoder_dummy_norms)
hparams.add_hparam("decoder_right_inputs", decoder_right_inputs)
hparams.add_hparam("decoder_right_layers", decoder_right_layers)
hparams.add_hparam("decoder_right_activations", decoder_dummy_activations)
hparams.add_hparam("decoder_right_output_dims", decoder_right_output_dims)
hparams.add_hparam("decoder_right_norms", decoder_dummy_norms)
hparams.add_hparam("decoder_combiner_functions", decoder_combiner_functions)
hparams.add_hparam("decoder_num_cells", decoder_num_cells)
hparams.add_hparam("decoder_final_combiner_function",
decoder_final_combiner_function)
class NasSeq2SeqTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for `NasSeq2Seq` and the branching-model parameter calculator."""

  def _count_trainable_params(self):
    """Returns the total number of scalars in all trainable graph variables."""
    return sum(
        _list_product(variable_tensor.shape.as_list())
        for variable_tensor in tf.trainable_variables())

  def _test_model(self, model_cls, hparams):
    """Builds a Translation NAS Net model on random data and checks logits.

    The model is built in TRAIN mode. Its logits must have shape
    [_BATCH_SIZE, _TARGET_LENGTH, 1, 1, _VOCAB_SIZE].

    Args:
      model_cls: Model class, instantiated as
        `model_cls(hparams, mode, problem_hparams)`.
      hparams: Model hparams; several fields are overridden in place to keep
        the model small.
    """
    tf.reset_default_graph()
    hparams.filter_size = 32
    hparams.num_heads = 1
    hparams.layer_prepostprocess_dropout = 0.0
    hparams.hidden_size = _HIDDEN_SIZE
    p_hparams = problem_hparams.test_problem_hparams(_VOCAB_SIZE, _VOCAB_SIZE,
                                                     hparams)
    hparams.problems = [p_hparams]
    # Random token ids in [0, _VOCAB_SIZE). `np.random.randint` replaces the
    # deprecated `-1 + np.random.random_integers(_VOCAB_SIZE, ...)`, which
    # sampled the same range.
    inputs = np.random.randint(
        _VOCAB_SIZE, size=(_BATCH_SIZE, _INPUT_LENGTH, 1, 1))
    targets = np.random.randint(
        _VOCAB_SIZE, size=(_BATCH_SIZE, _TARGET_LENGTH, 1, 1))
    features = {
        "inputs": tf.constant(inputs, dtype=tf.int32, name="inputs"),
        "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
        "target_space_id": tf.constant(1, dtype=tf.int32)
    }
    model = model_cls(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
    logits, _ = model(features)
    # `self.session()` replaces the deprecated `self.test_session()`.
    with self.session() as session:
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape,
                     (_BATCH_SIZE, _TARGET_LENGTH, 1, 1, _VOCAB_SIZE))

  def _get_encoder_hparams(self):
    """Returns small Transformer hparams plus per-layer encoder lists.

    There is one list entry per layer name in `layers.ENCODER_LAYERS`:
    - Output dims are 32 for the first and last layer and 64 in between.
    - Activations and norms are "none" for the first layer and
      "relu" / "layer_norm" for the rest.
    """
    hparams = transformer.transformer_small()
    hparams.add_hparam("encoder_layer_list",
                       layers.ENCODER_LAYERS.get_layer_names())
    hparams.add_hparam("encoder_output_dim_list", [32] + [64] *
                       (len(hparams.encoder_layer_list) - 2) + [32])
    hparams.add_hparam("encoder_activation_list", ["none"] + ["relu"] *
                       (len(hparams.encoder_layer_list) - 1))
    hparams.add_hparam("encoder_norm_list", ["none"] + ["layer_norm"] *
                       (len(hparams.encoder_layer_list) - 1))
    return hparams

  def test_nas_seq2seq(self):
    hparams = self._get_encoder_hparams()
    _add_transformer_branching_hparams(hparams)
    self._test_model(translation_nas_net.NasSeq2Seq, hparams)

  def _get_wrong_output_dim_decoder_hparams(self):
    """Returns `(hparams, wrong_output_size)` with a mis-sized decoder output.

    Resets the default graph. The decoder's last left-branch output dim is
    set to `_EMBEDDING_DEPTH + 1`, and the second-to-last dim is also
    increased by one.
    """
    tf.reset_default_graph()
    hparams = transformer.transformer_base()
    _add_transformer_branching_hparams(hparams)
    hparams.num_heads = 1
    # Purposely scale up the final embedding depth.
    wrong_output_size = _EMBEDDING_DEPTH + 1
    hparams.decoder_left_output_dims[-2] = (
        hparams.decoder_left_output_dims[-2] + 1)
    hparams.decoder_left_output_dims[-1] = wrong_output_size
    return hparams, wrong_output_size

  def test_nas_decoder_resizing_output(self):
    """`enforce_output_size` restores the decoder depth to _EMBEDDING_DEPTH."""
    hparams, wrong_size = self._get_wrong_output_dim_decoder_hparams()
    hparams.enforce_output_size = False
    input_tensor = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
    with tf.variable_scope("wrong"):
      wrong_size_decoder_output = translation_nas_net.nas_decoder(
          decoder_input=input_tensor,
          encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
          decoder_self_attention_bias=decoder_self_attention_bias,
          encoder_decoder_attention_bias=None,
          hparams=hparams)
    # Now add the correction.
    hparams.enforce_output_size = True
    with tf.variable_scope("correct"):
      correct_size_decoder_output = translation_nas_net.nas_decoder(
          decoder_input=input_tensor,
          encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
          decoder_self_attention_bias=decoder_self_attention_bias,
          encoder_decoder_attention_bias=None,
          hparams=hparams)
    with self.session() as session:
      session.run(tf.global_variables_initializer())
      wrong_output, correct_output = session.run(
          [wrong_size_decoder_output, correct_size_decoder_output])
    self.assertEqual(wrong_output.shape,
                     (_BATCH_SIZE, _INPUT_LENGTH, wrong_size))
    self.assertEqual(correct_output.shape,
                     (_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH))

  @parameterized.parameters([(_get_transformer_branching_encoder_config,
                              [512, 512, 2048, 512, 512]),
                             (_get_transformer_branching_decoder_config,
                              [512, 512, 512, 2048, 512, 512])])
  def test_calculate_branching_model_parameters_transformer(
      self, get_config, expected_hidden_depths):
    """Predicted parameter count and depths match the built graph."""
    tf.reset_default_graph()
    (num_cells, left_inputs, left_layers, left_output_dims, right_inputs,
     right_layers, right_output_dims, combiner_functions,
     final_combiner_function, dummy_activations, dummy_norms, layer_registry,
     is_decoder) = get_config()
    # Get predicted number of parameters.
    (predicted_num_params, output_size, hidden_depths,
     _) = translation_nas_net.calculate_branching_model_parameters(
         encoding_depth=_EMBEDDING_DEPTH,
         left_inputs=left_inputs,
         left_layers=left_layers,
         left_output_dims=left_output_dims,
         right_inputs=right_inputs,
         right_layers=right_layers,
         right_output_dims=right_output_dims,
         combiner_functions=combiner_functions,
         final_combiner_function=final_combiner_function,
         layer_registry=layer_registry,
         num_cells=num_cells,
         encoder_depth=_EMBEDDING_DEPTH)
    # Create model graph.
    input_tensor = tf.zeros([32, _INPUT_LENGTH, _EMBEDDING_DEPTH])
    hparams = transformer.transformer_small()
    if is_decoder:
      nonpadding = None
      mask_future = True
      decoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
      encoder_cell_outputs = [input_tensor] * 6
    else:
      nonpadding = tf.ones([32, _INPUT_LENGTH])
      mask_future = False
      decoder_self_attention_bias = None
      encoder_cell_outputs = None
    translation_nas_net.apply_nas_layers(
        input_tensor=input_tensor,
        left_inputs=left_inputs,
        left_layers=left_layers,
        left_activations=dummy_activations,
        left_output_dims=left_output_dims,
        left_norms=dummy_norms,
        right_inputs=right_inputs,
        right_layers=right_layers,
        right_activations=dummy_activations,
        right_output_dims=right_output_dims,
        right_norms=dummy_norms,
        combiner_functions=combiner_functions,
        final_combiner_function=final_combiner_function,
        num_cells=num_cells,
        nonpadding=nonpadding,
        layer_registry=layer_registry,
        mask_future=mask_future,
        hparams=hparams,
        var_scope="test",
        encoder_decoder_attention_bias=None,
        encoder_cell_outputs=encoder_cell_outputs,
        decoder_self_attention_bias=decoder_self_attention_bias,
        final_layer_norm=False)
    # Compare predicted and empirical parameter counts.
    self.assertEqual(self._count_trainable_params(), predicted_num_params)
    self.assertEqual(output_size, _EMBEDDING_DEPTH)
    self.assertEqual(hidden_depths, expected_hidden_depths)

  @parameterized.parameters([True, False])
  def test_calculate_branching_model_parameters_decoder_resize(
      self, enforce_output_size):
    """Predicted decoder parameter count matches, with or without resizing."""
    tf.reset_default_graph()
    hparams, _ = self._get_wrong_output_dim_decoder_hparams()
    hparams.enforce_output_size = enforce_output_size
    hparams.decoder_left_norms = [translation_nas_net.NO_NORM_KEY] * 5
    hparams.decoder_right_norms = [translation_nas_net.NO_NORM_KEY] * 5
    # Get predicted number of parameters.
    (predicted_num_params, _, _,
     _) = translation_nas_net.calculate_branching_model_parameters(
         encoding_depth=_EMBEDDING_DEPTH,
         left_inputs=hparams.decoder_left_inputs,
         left_layers=hparams.decoder_left_layers,
         left_output_dims=hparams.decoder_left_output_dims,
         right_inputs=hparams.decoder_right_inputs,
         right_layers=hparams.decoder_right_layers,
         right_output_dims=hparams.decoder_right_output_dims,
         combiner_functions=hparams.decoder_combiner_functions,
         final_combiner_function=hparams.decoder_final_combiner_function,
         layer_registry=layers.DECODER_LAYERS,
         num_cells=hparams.decoder_num_cells,
         encoder_depth=_EMBEDDING_DEPTH,
         enforce_output_size=enforce_output_size)
    # Build the decoder graph so its variables can be counted.
    input_tensor = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
    _ = translation_nas_net.nas_decoder(
        decoder_input=input_tensor,
        encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
        decoder_self_attention_bias=decoder_self_attention_bias,
        encoder_decoder_attention_bias=None,
        hparams=hparams,
        final_layer_norm=False)
    self.assertEqual(self._count_trainable_params(), predicted_num_params)

  def _predict_output_size(self, left_inputs, right_inputs):
    """Returns the output size predicted for a toy four-layer config.

    The layers and final combiner come from the Transformer encoder config.
    The output dims are distinct powers of ten, so the expected sum
    identifies which layers contribute to the output size.

    Args:
      left_inputs: Input index for each left-branch layer.
      right_inputs: Input index for each right-branch layer.
    """
    left_output_dims = [1, 10, 100, 1000]
    right_output_dims = [10000, 100000, 1000000, 10000000]
    right_layers = [
        layers.IDENTITY_REGISTRY_KEY, layers.STANDARD_CONV_1X1_REGISTRY_KEY,
        layers.STANDARD_CONV_1X1_REGISTRY_KEY, layers.IDENTITY_REGISTRY_KEY
    ]
    combiner_functions = [
        translation_nas_net.ADD_COMBINER_FUNC_KEY,
        translation_nas_net.ADD_COMBINER_FUNC_KEY,
        translation_nas_net.MULTIPLY_COMBINER_FUNC_KEY,
        translation_nas_net.CONCAT_COMBINER_FUNC_KEY
    ]
    (num_cells, _, left_layers, _, _, _, _, _, final_combiner_function, _, _,
     layer_registry, _) = _get_transformer_branching_encoder_config()
    (_, output_size, _,
     _) = translation_nas_net.calculate_branching_model_parameters(
         encoding_depth=_EMBEDDING_DEPTH,
         left_inputs=left_inputs,
         left_layers=left_layers,
         left_output_dims=left_output_dims,
         right_inputs=right_inputs,
         right_layers=right_layers,
         right_output_dims=right_output_dims,
         combiner_functions=combiner_functions,
         final_combiner_function=final_combiner_function,
         layer_registry=layer_registry,
         num_cells=num_cells,
         encoder_depth=_EMBEDDING_DEPTH,
         enforce_output_size=False,
         enforce_fixed_output_sizes=False)
    return output_size

  def test_calculate_branching_model_parameters_output_size_only_final(self):
    output_size = self._predict_output_size(
        left_inputs=[0, 1, 2, 3], right_inputs=[0, 1, 2, 3])
    self.assertEqual(output_size, 10001000)

  def test_calculate_branching_model_parameters_output_size_last_two(self):
    output_size = self._predict_output_size(
        left_inputs=[0, 1, 2, 2], right_inputs=[0, 1, 2, 2])
    self.assertEqual(output_size, 11001000)
if __name__ == "__main__":
  # Run the test cases defined in this module with TensorFlow's test runner.
  tf.test.main()
================================================
FILE: tensor2tensor/models/neural_assistant.py
================================================
# coding=utf-8
# Copyright 2023 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neural Assistant."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
@registry.register_model
class NeuralAssistant(transformer.Transformer):
"""Attention net. See file docstring."""
def __init__(self, *args, **kwargs):
  """Initializes the Transformer base model and caches the triple count.

  Args:
    *args: Positional arguments forwarded to `transformer.Transformer`.
    **kwargs: Keyword arguments forwarded to `transformer.Transformer`.
  """
  super(NeuralAssistant, self).__init__(*args, **kwargs)
  # Attention tensors collected for visualizing attention heads.
  self.attention_weights = {}
  # Number of knowledge triples per example, taken from `train_triple_num`.
  self.triple_num = self._hparams.train_triple_num
def model_fn(self, features):
  """Runs bottom, body and top, then folds the KB losses into the total loss.

  If the body already reports a "training" loss, the body output is returned
  as the logits and the body losses are returned unchanged.

  Otherwise `self.top` produces the logits and `losses["training"]` is set to
  0.0. Outside PREDICT and "attack" modes, `losses` is then replaced by a
  single "training" entry:

    kb_train_weight * transe_loss + (1 - kb_train_weight) *
        (kb_loss_weight * kb_loss + (1 - kb_loss_weight) * lm_loss)

  Here `kb_loss` and `transe_loss` are read from the body losses, the two
  weights come from hparams, and `lm_loss` is the numerator / denominator
  pair returned by `self.loss(logits, features)`.

  Args:
    features: Input features; forwarded to `self.bottom`, `self.top` and
      `self.loss`.

  Returns:
    A `(logits, losses)` tuple.
  """
  with tf.variable_scope(tf.get_variable_scope(), use_resource=True) as vs:
    self._add_variable_scope("model_fn", vs)
    transformed_features = self.bottom(features)
    if self.hparams.activation_dtype == "bfloat16":
      # sorted() materializes the items first, so reassigning keys is safe.
      for k, v in sorted(six.iteritems(transformed_features)):
        if v.dtype == tf.float32:
          transformed_features[k] = tf.cast(v, tf.bfloat16)
    with tf.variable_scope("body") as body_vs:
      self._add_variable_scope("body", body_vs)
      body_out = self.body(transformed_features)
    output, losses = self._normalize_body_output(body_out)
    if "training" in losses:
      tf.logging.info(
          "Skipping T2TModel top and loss because training loss returned from body"
      )
      logits = output
    else:
      # `warning` replaces the deprecated `warn` alias.
      tf.logging.warning("The loss will be computed in model_fn now.")
      logits = self.top(output, features)
      losses["training"] = 0.0
      # The body is expected to report both KB loss terms.
      cur_kb_loss = losses["kb_loss"]
      cur_knowledge_training_loss = losses["transe_loss"]
      cur_kb_loss_weight = self._hparams.kb_loss_weight
      kb_train_weight = self._hparams.kb_train_weight
      cur_lm_loss_weight = 1.0 - cur_kb_loss_weight
      # Finalize loss; skipped in PREDICT and "attack" modes.
      if (self._hparams.mode != tf_estimator.ModeKeys.PREDICT and
          self._hparams.mode != "attack"):
        lm_loss_num, lm_loss_denom = self.loss(logits, features)
        # Normalize once; reused by the total loss and its summary.
        lm_loss = lm_loss_num / lm_loss_denom
        total_loss = (
            kb_train_weight * cur_knowledge_training_loss +
            (1 - kb_train_weight) *
            (cur_kb_loss * cur_kb_loss_weight + lm_loss * cur_lm_loss_weight))
        tf.summary.scalar("kb_loss", cur_kb_loss)
        tf.summary.scalar("transe_loss", cur_knowledge_training_loss)
        tf.summary.scalar("lm_loss", lm_loss)
        tf.summary.scalar("cur_kb_loss_weight",
                          tf.reshape(cur_kb_loss_weight, []))
        tf.logging.info("Loss computed " + str(total_loss))
        # NOTE(review): the KB terms are dropped from `losses` here,
        # presumably so callers that sum all losses do not count them twice
        # -- TODO confirm.
        losses = {"training": total_loss}
    return logits, losses
def encode_knowledge_bottom(self, features):
tf.logging.info("Encoding knowledge " + str(self.triple_num))
# Make sure this is embeddings for triples
# [batch_size, triple_num*max_triple_length, 1, emb_dim]
fact_embedding = features["encoded_triples"]
# [batch_size, triple_num*max_triple_length, emb_dim]
fact_embedding = tf.squeeze(fact_embedding, 2)
kb_shape = common_layers.shape_list(fact_embedding)
batch_size = kb_shape[0]
embed_dim = kb_shape[2]
# [batch_size*triple_num, max_triple_length, emb_dim]
re_fact_embedding = tf.reshape(
fact_embedding, [batch_size * self.triple_num, -1, embed_dim],
name="reshape_fact_embedding")
#