Repository: bojone/keras_lookahead
Branch: master
Commit: 675da97af7e7
Files: 2
Total size: 3.2 KB
Directory structure:
gitextract_nrd2unsr/
├── README.md
└── lookahead.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
## Lookahead Optimizer for Keras
Keras implementation of [Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610).
Usage:
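```python
from keras.optimizers import Adam
from lookahead import Lookahead
# `model` is any Keras model you have already built.
model.compile(optimizer=Adam(1e-3), loss='mse')  # any Keras optimizer works
lookahead = Lookahead(k=5, alpha=0.5)  # initialize Lookahead
lookahead.inject(model)  # call after compile() and before the first fit()
```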
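Conceptually, the compiled optimizer updates the "fast" weights on every batch, and every `k` batches Lookahead moves the "slow" weights a fraction `alpha` toward them (`slow <- slow + alpha * (fast - slow)`) and then resets the fast weights to the slow ones. The sketch below illustrates that rule on a toy 1-D quadratic with plain SGD as the inner optimizer. The loss, learning rate, and the helper name `lookahead_sgd` are illustrative assumptions and are not part of this repository, which applies the rule to every trainable Keras weight.

```python
def lookahead_sgd(grad, w0, lr=0.1, k=5, alpha=0.5, steps=50):
    """Hypothetical helper: Lookahead wrapped around plain SGD on one scalar."""
    fast = slow = float(w0)
    for step in range(1, steps + 1):
        fast -= lr * grad(fast)            # one inner ("fast") optimizer step
        if step % k == 0:                  # every k steps, synchronize
            slow += alpha * (fast - slow)  # slow weights move toward fast weights
            fast = slow                    # fast weights restart from slow weights
    return fast, slow


# Toy loss f(w) = (w - 3)^2 with gradient 2 * (w - 3); the minimum is at w = 3.
fast, slow = lookahead_sgd(lambda w: 2.0 * (w - 3.0), w0=0.0)
print(fast, slow)
```

With `alpha=1.0` the slow weights simply copy the fast weights, so the procedure reduces to the inner optimizer alone; smaller `alpha` values damp each synchronization step.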
Chinese-language introduction: [click here](https://mp.weixin.qq.com/s/3J-28xd0pyToSy8zzKs1RA)
## Discussion
QQ discussion group: 67729435. To join the WeChat group, add the bot account spaces_ac_cn on WeChat.
================================================
FILE: lookahead.py
================================================
#! -*- coding: utf-8 -*-

from keras import backend as K


class Lookahead(object):
    """Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610) functionality for [keras](https://keras.io/).
    """

    def __init__(self, k=5, alpha=0.5):
        self.k = k          # number of fast steps between synchronizations
        self.alpha = alpha  # slow-weight step size (interpolation factor)
        self.count = 0      # number of training batches seen so far

    def inject(self, model):
        """Inject the Lookahead algorithm into the given model.
        The following code is modified from keras's _make_train_function method.
        See: https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497
        """
        if not hasattr(model, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        model._check_trainable_weights_consistency()
        if model.train_function is None:
            inputs = (model._feed_inputs +
                      model._feed_targets +
                      model._feed_sample_weights)
            if model._uses_dynamic_learning_phase():
                inputs += [K.learning_phase()]
            fast_params = model._collected_trainable_weights

            with K.name_scope('training'):
                with K.name_scope(model.optimizer.__class__.__name__):
                    training_updates = model.optimizer.get_updates(
                        params=fast_params,
                        loss=model.total_loss)
                    # Slow weights start as copies of the fast weights.
                    slow_params = [K.variable(p) for p in fast_params]
                fast_updates = (model.updates +
                                training_updates +
                                model.metrics_updates)

                # slow <- slow + alpha * (fast - slow), then fast <- slow.
                slow_updates, copy_updates = [], []
                for p, q in zip(fast_params, slow_params):
                    slow_updates.append(K.update(q, q + self.alpha * (p - q)))
                    copy_updates.append(K.update(p, q))

                # Gets loss and metrics. Updates weights at each call.
                fast_train_function = K.function(
                    inputs,
                    [model.total_loss] + model.metrics_tensors,
                    updates=fast_updates,
                    name='fast_train_function',
                    **model._function_kwargs)

                def F(inputs):
                    self.count += 1
                    R = fast_train_function(inputs)
                    # Every k batches: update the slow weights first,
                    # then reset the fast weights to them.
                    if self.count % self.k == 0:
                        K.batch_get_value(slow_updates)
                        K.batch_get_value(copy_updates)
                    return R

                model.train_function = F
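Rather than defining a new optimizer class, `inject` replaces `model.train_function` with a closure that runs the ordinary training step and, on every `k`-th call, applies the slow-weight and copy updates. The same call-counting design can be sketched without Keras. In the sketch below, `make_lookahead_step` and `sgd_step` are hypothetical names used only to illustrate the control flow of `F`; the fast weights are a plain list of floats that the wrapped step mutates in place.

```python
from typing import Callable, List


def make_lookahead_step(step: Callable[[List[float]], float],
                        weights: List[float],
                        k: int = 5,
                        alpha: float = 0.5) -> Callable[[], float]:
    """Hypothetical wrapper mirroring the closure `F` in lookahead.py."""
    slow = list(weights)  # slow copy, analogous to K.variable(p) for each p
    count = 0

    def train_step() -> float:
        nonlocal count
        count += 1
        loss = step(weights)  # ordinary fast update (fast_train_function)
        if count % k == 0:
            for i in range(len(weights)):
                slow[i] = slow[i] + alpha * (weights[i] - slow[i])  # slow_updates
                weights[i] = slow[i]                                # copy_updates
        return loss  # like F, return the result of the fast step

    return train_step


def sgd_step(ws: List[float]) -> float:
    loss = (ws[0] - 3.0) ** 2           # loss before the update
    ws[0] -= 0.1 * 2.0 * (ws[0] - 3.0)  # one SGD step on f(w) = (w - 3)^2
    return loss


w = [0.0]
train = make_lookahead_step(sgd_step, w, k=5, alpha=0.5)
losses = [train() for _ in range(5)]  # the 5th call triggers a synchronization
print(w, losses)
```

Wrapping the training function keeps the inner optimizer untouched, which is why the README can say any optimizer may be passed to `compile`; the trade-off is that the wrapper relies on private Keras attributes such as `_feed_inputs` and `_function_kwargs`.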