Repository: bojone/keras_lookahead
Branch: master
Commit: 675da97af7e7
Files: 2
Total size: 3.2 KB

Directory structure:
gitextract_nrd2unsr/
├── README.md
└── lookahead.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================

## Lookahead Optimizer for Keras

A Keras implementation of [Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610).

Usage:

```
model.compile(optimizer=Adam(1e-3), loss='mse')  # any optimizer you like
lookahead = Lookahead(k=5, alpha=0.5)  # initialize Lookahead
lookahead.inject(model)  # add it into the model
```

Chinese introduction: [click here](https://mp.weixin.qq.com/s/3J-28xd0pyToSy8zzKs1RA)

## Contact

QQ group: 67729435. For the WeChat group, add the bot's WeChat ID spaces_ac_cn.

================================================
FILE: lookahead.py
================================================

#! -*- coding: utf-8 -*-

from keras import backend as K


class Lookahead(object):
    """Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610)
    functionality for [keras](https://keras.io/).
    """
    def __init__(self, k=5, alpha=0.5):
        self.k = k
        self.alpha = alpha
        self.count = 0

    def inject(self, model):
        """Inject the Lookahead algorithm for the given model.
        The following code is modified from keras's _make_train_function method.
        See: https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497
        """
        if not hasattr(model, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        model._check_trainable_weights_consistency()
        if model.train_function is None:
            inputs = (model._feed_inputs +
                      model._feed_targets +
                      model._feed_sample_weights)
            if model._uses_dynamic_learning_phase():
                inputs += [K.learning_phase()]
            fast_params = model._collected_trainable_weights
            with K.name_scope('training'):
                with K.name_scope(model.optimizer.__class__.__name__):
                    # The inner optimizer produces the "fast" weight updates.
                    training_updates = model.optimizer.get_updates(
                        params=fast_params,
                        loss=model.total_loss)
                # The "slow" weights start as a copy of the fast weights.
                slow_params = [K.variable(p) for p in fast_params]
                fast_updates = (model.updates +
                                training_updates +
                                model.metrics_updates)
                slow_updates, copy_updates = [], []
                for p, q in zip(fast_params, slow_params):
                    # 1 step back: q <- q + alpha * (p - q), then p <- q.
                    slow_updates.append(K.update(q, q + self.alpha * (p - q)))
                    copy_updates.append(K.update(p, q))
                # Gets loss and metrics. Updates weights at each call.
                fast_train_function = K.function(
                    inputs,
                    [model.total_loss] + model.metrics_tensors,
                    updates=fast_updates,
                    name='fast_train_function',
                    **model._function_kwargs)

                def F(inputs):
                    self.count += 1
                    R = fast_train_function(inputs)
                    if self.count % self.k == 0:
                        # Every k fast steps, apply the slow update and copy
                        # the slow weights back into the fast weights.
                        K.batch_get_value(slow_updates)
                        K.batch_get_value(copy_updates)
                    return R

                model.train_function = F
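The graph surgery above implements the same update rule the paper describes: run k inner-optimizer steps on the fast weights, then interpolate the slow weights toward them and reset the fast weights. A minimal pure-Python sketch of that rule, independent of Keras (the names `lookahead_sgd` and `grad_fn` are hypothetical, chosen for illustration only):

```python
def lookahead_sgd(grad_fn, theta, lr=0.1, k=5, alpha=0.5, outer_steps=4):
    """Run `outer_steps` Lookahead cycles of plain SGD on a scalar parameter.

    Each cycle: k fast SGD steps, then one slow interpolation step.
    """
    slow = theta  # slow weights start at the initial point
    for _ in range(outer_steps):
        fast = slow
        for _ in range(k):                     # k steps forward
            fast = fast - lr * grad_fn(fast)
        slow = slow + alpha * (fast - slow)    # 1 step back
    return slow

# Example: minimize f(x) = x^2, whose gradient is 2x.
x = lookahead_sgd(lambda t: 2.0 * t, theta=1.0)
```

Each cycle here shrinks the slow weight by the factor `alpha * lr_decay + (1 - alpha)`, so `x` steadily approaches the minimum at 0, illustrating why the interpolation dampens the variance of the fast trajectory.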