[
  {
    "path": "README.md",
    "content": "## Lookahead Optimizer for Keras\n\nKeras implement of [Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610).\n\nUsage:\n```\nmodel.compile(optimizer=Adam(1e-3), loss='mse') # Any optimizer\nlookahead = Lookahead(k=5, alpha=0.5) # Initialize Lookahead\nlookahead.inject(model) # add into model\n```\n\n## Lookahead优化器的Keras实现\n\n论文[《Lookahead Optimizer: k steps forward, 1 step back》](https://arxiv.org/abs/1907.08610)的Keras实现。\n\n用法：\n```\nmodel.compile(optimizer=Adam(1e-3), loss='mse') # 用你想用的优化器\nlookahead = Lookahead(k=5, alpha=0.5) # 初始化Lookahead\nlookahead.inject(model) # 插入到模型中\n```\n\n中文介绍：[点击进入](https://mp.weixin.qq.com/s/3J-28xd0pyToSy8zzKs1RA)\n\n## 交流\nQQ交流群：67729435，微信群请加机器人微信号spaces_ac_cn\n"
  },
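  {
    "path": "example.py",
    "content": "#! -*- coding: utf-8 -*-\n\"\"\"Minimal end-to-end sketch of using Lookahead with a Keras model.\n\nThis example file is an illustrative assumption, not part of the paper:\nthe toy data, the model architecture, and the hyperparameters are all\narbitrary. It only demonstrates the compile -> inject -> fit order that\nlookahead.py expects (inject must run after compile, because it wraps\nthe model's train_function).\n\"\"\"\n\nimport numpy as np\nfrom keras.layers import Dense\nfrom keras.models import Sequential\nfrom keras.optimizers import Adam\n\nfrom lookahead import Lookahead\n\n# Toy regression problem: learn y = sum(x) from random inputs.\nx = np.random.randn(1024, 16)\ny = x.sum(axis=1, keepdims=True)\n\nmodel = Sequential([\n    Dense(32, activation='relu', input_shape=(16,)),\n    Dense(1),\n])\n\n# Compile first with any inner (\"fast\") optimizer ...\nmodel.compile(optimizer=Adam(1e-3), loss='mse')\n\n# ... then wrap it: every k=5 fast steps, the slow weights move\n# alpha=0.5 of the way toward the fast weights and are copied back.\nlookahead = Lookahead(k=5, alpha=0.5)\nlookahead.inject(model)\n\nmodel.fit(x, y, batch_size=32, epochs=5, verbose=2)\n"
  },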
  {
    "path": "lookahead.py",
    "content": "#! -*- coding: utf-8 -*-\n\nfrom keras import backend as K\n\n\nclass Lookahead(object):\n    \"\"\"Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610) functionality for [keras](https://keras.io/).\n    \"\"\"\n\n    def __init__(self, k=5, alpha=0.5):\n        self.k = k\n        self.alpha = alpha\n        self.count = 0\n\n    def inject(self, model):\n        \"\"\"Inject the Lookahead algorithm for the given model.\n        The following code is modified from keras's _make_train_function method.\n        See: https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497\n        \"\"\"\n        if not hasattr(model, 'train_function'):\n            raise RuntimeError('You must compile your model before using it.')\n\n        model._check_trainable_weights_consistency()\n\n        if model.train_function is None:\n            inputs = (model._feed_inputs +\n                      model._feed_targets +\n                      model._feed_sample_weights)\n            if model._uses_dynamic_learning_phase():\n                inputs += [K.learning_phase()]\n            fast_params = model._collected_trainable_weights\n\n            with K.name_scope('training'):\n                with K.name_scope(model.optimizer.__class__.__name__):\n                    training_updates = model.optimizer.get_updates(\n                        params=fast_params,\n                        loss=model.total_loss)\n                    slow_params = [K.variable(p) for p in fast_params]\n                fast_updates = (model.updates +\n                                training_updates +\n                                model.metrics_updates)\n\n                slow_updates, copy_updates = [], []\n                for p, q in zip(fast_params, slow_params):\n                    slow_updates.append(K.update(q, q + self.alpha * (p - q)))\n                    copy_updates.append(K.update(p, q))\n\n                # Gets loss and metrics. Updates weights at each call.\n                fast_train_function = K.function(\n                    inputs,\n                    [model.total_loss] + model.metrics_tensors,\n                    updates=fast_updates,\n                    name='fast_train_function',\n                    **model._function_kwargs)\n\n                def F(inputs):\n                    self.count += 1\n                    R = fast_train_function(inputs)\n                    if self.count % self.k == 0:\n                        K.batch_get_value(slow_updates)\n                        K.batch_get_value(copy_updates)\n                    return R\n                \n                model.train_function = F\n"
  }
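  ,
  {
    "path": "lookahead_numpy_sketch.py",
    "content": "#! -*- coding: utf-8 -*-\n\"\"\"Backend-free sketch of the Lookahead update rule from the paper.\n\nThis file is an illustrative assumption, not part of the original code:\nit restates in plain NumPy what lookahead.py wires into the Keras graph,\non a 1-D quadratic so the arithmetic is easy to follow. Take k fast\n(here: plain SGD) steps on theta, then move the slow weights phi a\nfraction alpha toward theta and reset theta to phi.\n\"\"\"\n\nimport numpy as np\n\n\ndef lookahead_sgd(grad, theta0, lr=0.1, k=5, alpha=0.5, steps=100):\n    \"\"\"Minimize a function, given its gradient, with Lookahead-wrapped SGD.\"\"\"\n    theta = np.array(theta0, dtype=float)  # fast weights\n    phi = theta.copy()                     # slow weights\n    for t in range(1, steps + 1):\n        theta -= lr * grad(theta)          # fast step (inner optimizer)\n        if t % k == 0:                     # every k steps: \"1 step back\"\n            phi += alpha * (theta - phi)   # phi <- phi + alpha * (theta - phi)\n            theta = phi.copy()             # theta <- phi\n    return phi\n\n\nif __name__ == '__main__':\n    # Quadratic f(x) = 0.5 * x**2 with gradient f'(x) = x; minimum at 0.\n    final = lookahead_sgd(grad=lambda x: x, theta0=[5.0])\n    print('final slow weights:', final)  # converges toward 0\n"
  }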
]