[
  {
    "path": "README.md",
    "content": "# HFT-price-prediction\nA project of using machine learning model (tree-based) to predict instrument price up or down in high frequency trading.\n\n## Project Background\nA data science hands-on exercise of a high frequency trading company. \n\n## Task\nTo build a model with the given data to predict whether the trading price will go up or down in a short future. (classification problem)\n\n## Data Explanation\n### Feature Columns\n<b>timestamp</b>  str, datetime string.<br>\n<b>bid_price</b>  float, price of current bid in the market.<br>\n<b>bid_qty</b>  float, quantity currently available at the bid price.<br>\n<b>bid_price</b>  float, price of current ask in the market.<br>\n<b>ask_qty</b>  float, quantity currently available at the ask price.<br>\n<b>trade_price</b>  float, last traded price.<br>\n<b>sum_trade_1s</b>  float, sum of quantity traded over the last second.<br>\n<b>bid_advance_time</b>  float, seconds since bid price last advanced.<br>\n<b>ask_advance_time</b>  float, seconds since ask price last advanced.<br>\n<b>last_trade_time</b>  float, seconds since last trade.<br>\n### Labels\n<b>_1s_side</b> int<br>\n<b>_3s_side</b> int<br>\n<b>_5s_side</b> int<br>\nLabels indicate what is type of the first event that will happen in the next x seconds, where:<br>\n<b>0</b> -- No price change.<br>\n<b>1</b> -- Bid price decreased.<br>\n<b>2</b> -- Ask price increased.<br>\n\n## Process\n### Preprocessing\n<b>data type conversion</b>: **`preprocessing()`**<br>\n<b>data check</b>: **`check_null()`**<br>\n<b>missing value handling</b>: **`fill_null()`**,\nbased on the null check and basic logic, most of the sum_trade_1s null value happens when last_trade_time larger\nthan 1 sec (in this case sum_trade_1s should be 0). Therefore, we make an assumption that all the sum_trade_1s null\nvalue could be filled with 0. 
Based on this assumption, a null last_trade_time can be filled with the last_trade_time of the\nprevious record plus the time elapsed since that record, if the interval between the two records is within 1 sec.\n### Feature Engineering\n<b>correlation filter</b>: **`correlation_filter.filter()`**, removes highly correlated columns (absolute correlation >= 0.99 by default) to reduce data redundancy.<br>\n<b>logical feature engineering</b>: **`feature_eng.basic_features()`**, builds features based on trading logic, such as the bid-ask spread and the bid/ask quantity difference.<br>\n<b>time-rolling feature engineering</b>: **`feature_eng.lag_rolling_features()`**, builds lag features and rolling statistics (sum, mean, max, min, std) over the previous n records and the previous n seconds.<br>\n### Feature Selection\n**`feature_selection.select()`**, a hybrid approach that takes the union of the features chosen by genetic algorithm selection and by feature importance selection.<br>\n<b>genetic algorithm selection</b>: **`feature_selection.GA_features()`** <br>\n<b>feature importance selection</b>: **`feature_selection.rf_imp_features()`** <br>\n### Modelling\nEnsemble of a lightGBM model and a random forest model; their predicted class probabilities are averaged.<br>\n<b>random forest</b>: **`model.random_forest()`** <br>\n<b>lightGBM</b>: **`model.lightgbm()`** <br>\n### Parameter Tuning\nThe size of the search space decides how the lightGBM model's parameters are tuned: genetic search when there are more than 1000 parameter combinations, grid search otherwise.<br>\n<b>grid search</b>: **`model.GS_tune_lgbm()`** <br>\n<b>genetic search</b>: **`model.GA_tune_lgbm()`** <br>\n## Performance\nOut-of-sample classification accuracy is roughly 76-78%, which suggests that the model's prediction of short-term price movement is acceptable.\n"
  },
  {
    "path": "features.txt",
    "content": "{\"keep_features\": [\"bid_ask_qty_diff_diff_lag_5\", \"up_down_rolling_std_5\", \"spread_diff_rolling_mean_20\", \"spread_diff_rolling_mean_5s\", \"bid_price_rolling_std_1s\", \"bid_advance_time_rolling_mean_1s\", \"ask_qty_diff_rolling_max_10s\", \"ask_price_diff_rolling_std_3s\", \"ask_qty_rolling_std_10s\", \"bid_ask_qty_diff_rolling_std_20\", \"ask_advance_time_lag_2\", \"bid_ask_qty_total_rolling_max_10\", \"bid_ask_qty_diff_rolling_sum_5\", \"bid_qty_rolling_min_5\", \"bid_ask_qty_diff_diff_rolling_sum_3s\", \"sum_trade_1s_rolling_std_1s\", \"spread_rolling_mean_1s\", \"trade_price_diff_rolling_sum_10\", \"ask_qty_diff_rolling_sum_10s\", \"ask_price_diff_rolling_mean_5s\", \"sum_trade_1s_diff_rolling_sum_20\", \"bid_price_lag_5\", \"sum_trade_1s_rolling_mean_5\", \"bid_ask_qty_diff_rolling_min_5\", \"bid_ask_qty_diff_diff_rolling_std_3s\", \"bid_ask_qty_total_rolling_min_5\", \"bid_advance_time_diff_lag_2\", \"trade_price_compare\", \"bid_ask_qty_diff_diff_rolling_mean_20\", \"trade_price_diff_rolling_sum_3s\", \"bid_ask_qty_diff_rolling_sum_1s\", \"bid_qty\", \"ask_advance_time_rolling_mean_5s\", \"spread_diff_rolling_std_1s\", \"trade_price_compare_diff_rolling_std_1s\", \"bid_ask_qty_diff\", \"ask_qty_lag_1\", \"ask_qty_diff_rolling_sum_1s\", \"trade_price_compare_diff_rolling_sum_5\", \"spread\", \"bid_qty_lag_1\", \"bid_ask_qty_diff_rolling_mean_10\", \"bid_qty_lag_2\", \"bid_price_lag_3\", \"ask_qty_rolling_min_3s\", \"ask_advance_time_lag_4\", \"spread_diff_rolling_std_3s\", \"bid_qty_rolling_max_20\", \"ask_qty_lag_3\", \"bid_qty_diff_lag_5\", \"bid_price_diff_rolling_sum_5s\", \"trade_price_compare_diff_lag_4\", \"bid_price_diff_lag_4\", \"bid_qty_diff_rolling_sum_1s\", \"bid_ask_qty_diff_diff_rolling_max_1s\", \"bid_advance_time_rolling_mean_3s\", \"ask_advance_time_diff_lag_1\", \"ask_qty_rolling_min_5\", \"spread_rolling_std_3s\", \"bid_advance_time_rolling_std_20\", \"ask_qty_diff_rolling_min_20\", 
\"sum_trade_1s_rolling_mean_10\", \"spread_diff_rolling_std_20\", \"ask_qty_rolling_mean_5\", \"bid_qty_rolling_min_10\", \"trade_price_compare_diff_lag_5\", \"bid_price_rolling_std_5\", \"trade_price_rolling_mean_10\", \"sum_trade_1s_diff_rolling_std_10\", \"bid_advance_time_diff_rolling_sum_5s\", \"ask_qty_lag_2\", \"trade_price_pos_diff_rolling_std_10s\", \"ask_advance_time_diff_rolling_mean_5\", \"ask_qty_rolling_min_10\", \"sum_trade_1s_diff_lag_5\", \"last_trade_time_diff_lag_4\", \"bid_qty_diff_rolling_std_5\", \"bid_price_diff_lag_3\", \"ask_advance_time_lag_3\", \"ask_qty_rolling_mean_20\", \"ask_qty_diff_rolling_mean_5\", \"bid_ask_qty_diff_diff_rolling_sum_10s\", \"bid_advance_time_rolling_mean_5s\", \"sum_trade_1s_lag_1\", \"bid_qty_rolling_min_3s\", \"bid_qty_rolling_max_5s\", \"sum_trade_1s_diff_lag_2\", \"bid_ask_qty_total_rolling_max_10s\", \"bid_qty_rolling_mean_10\", \"bid_advance_time_lag_1\", \"bid_ask_qty_diff_lag_1\", \"bid_ask_qty_diff_diff_rolling_min_1s\", \"bid_qty_diff_rolling_std_10s\", \"bid_price_rolling_std_5s\", \"ask_qty_diff_rolling_std_5s\", \"bid_qty_diff_rolling_max_10\", \"last_trade_time\", \"ask_qty_diff_rolling_mean_1s\", \"trade_price_pos_diff_rolling_mean_3s\", \"bid_ask_qty_total_diff_rolling_max_3s\", \"ask_qty_diff_rolling_sum_3s\", \"last_trade_time_diff_rolling_mean_5s\", \"bid_ask_qty_total_diff_rolling_max_10\", \"bid_qty_rolling_mean_5\", \"ask_qty\", \"bid_ask_qty_diff_diff_rolling_mean_5s\", \"bid_ask_qty_total_diff_rolling_sum_5\", \"bid_qty_rolling_min_20\", \"last_trade_time_diff_rolling_sum_5\", \"bid_price_rolling_mean_10s\", \"ask_advance_time_diff_rolling_mean_1s\", \"sum_trade_1s_diff\"], \"correlation_remove\": [\"ask_price\"]}"
  },
  {
    "path": "modelling_pipeline.py",
    "content": "import pandas as pd\r\nimport numpy as np\r\nimport json\r\nfrom itertools import product\r\nfrom bisect import bisect_left\r\nfrom sklearn.ensemble import RandomForestClassifier\r\nfrom sklearn.model_selection import TimeSeriesSplit\r\nfrom genetic_selection import GeneticSelectionCV\r\nfrom lightgbm import LGBMClassifier\r\nfrom evolutionary_search import EvolutionaryAlgorithmSearchCV\r\nfrom sklearn.model_selection import GridSearchCV\r\nfrom sklearn.externals import joblib\r\nfrom scipy.stats import mode\r\n\r\n\r\ndef preprocessing(data):\r\n    '''align data type and time order'''\r\n    float_list = [\r\n        'bid_price',\r\n        'bid_qty',\r\n        'ask_price',\r\n        'ask_qty',\r\n        'trade_price',\r\n        'sum_trade_1s',\r\n        'bid_advance_time',\r\n        'ask_advance_time',\r\n        'last_trade_time',\r\n    ]\r\n\r\n    data['timestamp'] = pd.to_datetime(data['timestamp'])\r\n    for i in float_list:\r\n        data[i] = data[i].astype(float)\r\n\r\n    data = data.sort_values(by='timestamp', ascending=True).reset_index(drop=True)\r\n    return data\r\n\r\n\r\ndef check_null(data):\r\n    '''check null values in dataframe'''\r\n    data = data.copy()\r\n    have_null_cols = list(data.columns[data.isnull().any()])\r\n    print('Columns with null values are {}'.format(', '.join(have_null_cols)))\r\n    for i in have_null_cols:\r\n        print('number of rows that column {} is null: {}'.format(i, data[i].isnull().sum()))\r\n        print('null percentage is {}'.format(round(data[i].isnull().sum() / data.shape[0], 2)))\r\n\r\n    stat1 = data['sum_trade_1s'][data['last_trade_time'].isnull()].notnull().sum()\r\n    stat2 = data['last_trade_time'][data['sum_trade_1s'].isnull()].notnull().sum()\r\n    stat3 = data['sum_trade_1s'][data['last_trade_time'] >= 1].isnull().sum()\r\n    stat4 = stat3 / data['sum_trade_1s'].isnull().sum()\r\n    print('number of rows sum_trade_1s is not null when last_trade_time is not: 
{}'.format(stat1))\r\n    print('number of rows where last_trade_time is not null but sum_trade_1s is null: {}'.format(stat2))\r\n    print('number of rows sum_trade_1s null at last_trade_time >= 1: {}, percentage: {}'.format(stat3, round(stat4, 2)))\r\n\r\n\r\ndef fill_null(data):\r\n    '''\r\n    Based on the null check and basic logic, most of the sum_trade_1s null values occur when last_trade_time is larger\r\n    than 1 sec (in this case sum_trade_1s should be 0). Therefore, we assume that all the sum_trade_1s null\r\n    values can be filled with 0. Based on this assumption, last_trade_time can be filled with the last_trade_time of the\r\n    previous record plus the elapsed time if the record interval is within 1 sec.\r\n    '''\r\n\r\n    class last_trade_time_filler:\r\n        prev_last_trade_time = None\r\n        prev_timestamp = None\r\n\r\n        @classmethod\r\n        def fill(cls, index):\r\n            last_trade_time = data.loc[index, 'last_trade_time']\r\n            timestamp = data.loc[index, 'timestamp']\r\n\r\n            # the first record has no predecessor, so a null value there stays null\r\n            if pd.isnull(last_trade_time) and cls.prev_timestamp is not None:\r\n                # total_seconds() is the full interval; .microseconds is only its sub-second component\r\n                time_interval = (timestamp - cls.prev_timestamp).total_seconds()\r\n                if time_interval <= 1:\r\n                    last_trade_time = cls.prev_last_trade_time + time_interval\r\n                else:\r\n                    last_trade_time = np.nan\r\n\r\n            cls.prev_last_trade_time = last_trade_time\r\n            cls.prev_timestamp = timestamp\r\n\r\n            return last_trade_time\r\n\r\n    data = data.copy()\r\n    data.loc[data['sum_trade_1s'].isnull(), 'sum_trade_1s'] = 0\r\n    data['last_trade_time'] = data.index.map(last_trade_time_filler.fill)\r\n    print('number of columns with null values is now: {}'.format(len(list(data.columns[data.isnull().any()]))))\r\n\r\n    return data\r\n\r\n\r\ndef x_y_split(data):\r\n    label_cols = ['_1s_side', '_3s_side', '_5s_side']\r\n    feature_cols = list(set(data.columns) - set(label_cols))\r\n    y = 
data[label_cols].copy()\r\n    x = data[feature_cols].copy()\r\n\r\n    return x, y\r\n\r\n\r\nclass correlation_filter:\r\n    remove_cols = []\r\n\r\n    @classmethod\r\n    def filter(cls, x, threshold=0.99):\r\n        x = x.copy()\r\n        index2col = {i: col for i, col in enumerate(x.columns)}\r\n        corr = np.array(x.corr())\r\n        correlated_pairs = list(zip(*np.where(np.abs(corr) >= threshold)))\r\n        to_be_delete = []\r\n        for i, j in correlated_pairs:\r\n            former = index2col[i]\r\n            latter = index2col[j]\r\n            if former != latter:\r\n                # merge the pair with every existing group it overlaps, so groups stay disjoint\r\n                pair = {former, latter}\r\n                overlapping = [s for s in to_be_delete if s & pair]\r\n                to_be_delete = [s for s in to_be_delete if not (s & pair)]\r\n                to_be_delete.append(pair.union(*overlapping))\r\n\r\n        # keep one arbitrary column from each correlated group and drop the rest\r\n        for i in to_be_delete:\r\n            delete_set = i.copy()\r\n            delete_set.pop()\r\n            x = x.drop(list(delete_set), axis=1)\r\n            cls.remove_cols += list(delete_set)\r\n\r\n        return x\r\n\r\n\r\nclass feature_eng:\r\n    timestamp = None\r\n    max_lag = 5\r\n    num_window = [5, 10, 20]\r\n    sec_window = [1, 3, 5, 10]\r\n    rolling_sum_cols = []\r\n    rolling_mean_cols = []\r\n    rolling_max_cols = []\r\n    rolling_min_cols = []\r\n    rolling_std_cols = []\r\n\r\n    @staticmethod\r\n    def bid_ask_spread(data):\r\n        data['spread'] = data['ask_price'] - data['bid_price']\r\n\r\n    @staticmethod\r\n    def bid_ask_qty_comb(data):\r\n        data['bid_ask_qty_total'] = data['ask_qty'] + data['bid_qty']\r\n        data['bid_ask_qty_diff'] = data['ask_qty'] - data['bid_qty']\r\n\r\n    @staticmethod\r\n    def trade_price_feature(data):\r\n        data['trade_price_compare'] = 0  # when trade price between current bid 
and ask price\r\n        data.loc[data['trade_price'] <= data[\r\n            'bid_price'], 'trade_price_compare'] = -1  # when trade price on current bid side\r\n        data.loc[data['trade_price'] >= data[\r\n            'ask_price'], 'trade_price_compare'] = 1  # when trade price on current sell side\r\n\r\n        # whether trade price happens on bid side or ask side during the time it happens\r\n        last_trade_timestamp = data['timestamp'] - pd.to_timedelta(data['last_trade_time'], unit='s')\r\n        idx_list = [bisect_left(data['timestamp'], i) for i in list(last_trade_timestamp)]\r\n        trade_price_pos = []\r\n        for i, index in enumerate(idx_list):\r\n            index1 = index\r\n            index2 = index1 + 1 if index1 < data.shape[0] - 1 else index1\r\n            bid1 = data['bid_price'][index1]\r\n            bid2 = data['bid_price'][index2]\r\n            ask1 = data['ask_price'][index1]\r\n            ask2 = data['ask_price'][index2]\r\n            trade_price = data['trade_price'][i]\r\n            if (bid1 <= trade_price <= bid2) or (bid2 <= trade_price <= bid1):\r\n                trade_price_pos.append(-1)  # happen on bid side\r\n            elif (ask1 <= trade_price <= ask2) or (ask2 <= trade_price <= ask1):\r\n                trade_price_pos.append(1)  # happen on sell side\r\n            else:\r\n                trade_price_pos.append(0)  # unknown case\r\n        data['trade_price_pos'] = trade_price_pos\r\n\r\n    @staticmethod\r\n    def diff_feature(data):\r\n        for i in set(data.columns) - {'timestamp'}:\r\n            new_name = '{}_diff'.format(i)\r\n            data[new_name] = data[i] - data[i].shift(1)\r\n\r\n    @staticmethod\r\n    def up_or_down(data):\r\n        data['up_down'] = 0\r\n        data.loc[data['bid_price_diff'] < 0, 'up_down'] = -1\r\n        data.loc[data['ask_price_diff'] > 0, 'up_down'] = 1\r\n\r\n    @staticmethod\r\n    def lag_feature(data, col, lag):\r\n        new_col_name = 
'{}_lag_{}'.format(col, lag)\r\n        data[new_col_name] = data[col].shift(lag)\r\n\r\n    @staticmethod\r\n    def rolling_feature(data, col, window, feature):\r\n        rolling = data[col].rolling(window=window)\r\n        new_col = '{}_rolling_{}_{}'.format(col, feature, window)\r\n\r\n        if feature == 'sum':\r\n            data[new_col] = rolling.sum()\r\n        elif feature == 'mean':\r\n            data[new_col] = rolling.mean()\r\n        elif feature == 'max':\r\n            data[new_col] = rolling.max()\r\n        elif feature == 'min':\r\n            data[new_col] = rolling.min()\r\n        elif feature == 'std':\r\n            data[new_col] = rolling.std()\r\n        elif feature == 'mode':\r\n            data[new_col] = rolling.apply(lambda x: mode(x)[0])\r\n\r\n    @classmethod\r\n    def basic_features(cls, data):\r\n        data = data.copy()\r\n        cls.timestamp = data['timestamp']\r\n\r\n        cls.bid_ask_spread(data)\r\n        cls.bid_ask_qty_comb(data)\r\n        cls.trade_price_feature(data)\r\n        cls.diff_feature(data)\r\n        cls.up_or_down(data)\r\n\r\n        data = data.drop('timestamp', axis=1)\r\n        return data\r\n\r\n    @classmethod\r\n    def lag_rolling_features(cls, data):\r\n        data = data.copy()\r\n\r\n        # get lag and rolling feature based on previous n records\r\n        rolling_cols = set(data.columns) - {'trade_price_compare', 'trade_price_pos'}\r\n        cls.rolling_sum_cols = [i for i in rolling_cols if 'diff' in i or 'up_down' in i]\r\n        cls.rolling_mean_cols = rolling_cols\r\n        cls.rolling_max_cols = [i for i in rolling_cols if 'bid_qty' in i or 'ask_qty' in i]\r\n        cls.rolling_min_cols = [i for i in rolling_cols if 'bid_qty' in i or 'ask_qty' in i]\r\n        cls.rolling_std_cols = rolling_cols\r\n\r\n        for col in rolling_cols:\r\n            for lag in range(1, cls.max_lag + 1):\r\n                cls.lag_feature(data, col, lag)\r\n\r\n        for col in 
rolling_cols:\r\n            for num_window in cls.num_window:\r\n                if col in cls.rolling_sum_cols:\r\n                    cls.rolling_feature(data, col, num_window, 'sum')\r\n                if col in cls.rolling_mean_cols:\r\n                    cls.rolling_feature(data, col, num_window, 'mean')\r\n                if col in cls.rolling_max_cols:\r\n                    cls.rolling_feature(data, col, num_window, 'max')\r\n                if col in cls.rolling_min_cols:\r\n                    cls.rolling_feature(data, col, num_window, 'min')\r\n                if col in cls.rolling_std_cols:\r\n                    cls.rolling_feature(data, col, num_window, 'std')\r\n\r\n        # get rolling feature based on previous n seconds\r\n        data.index = cls.timestamp\r\n        for col in rolling_cols:\r\n            for sec_window in cls.sec_window:\r\n                sec_window = '{}s'.format(sec_window)\r\n                if col in cls.rolling_sum_cols:\r\n                    cls.rolling_feature(data, col, sec_window, 'sum')\r\n                if col in cls.rolling_mean_cols:\r\n                    cls.rolling_feature(data, col, sec_window, 'mean')\r\n                if col in cls.rolling_max_cols:\r\n                    cls.rolling_feature(data, col, sec_window, 'max')\r\n                if col in cls.rolling_min_cols:\r\n                    cls.rolling_feature(data, col, sec_window, 'min')\r\n                if col in cls.rolling_std_cols:\r\n                    cls.rolling_feature(data, col, sec_window, 'std')\r\n                if col in ['up_down', 'trade_price_compare', 'trade_price_pos']:\r\n                    cls.rolling_feature(data, col, sec_window, 'mode')\r\n\r\n        return data\r\n\r\n    @staticmethod\r\n    def remove_na(x, y):\r\n        x = x.reset_index(drop=True)\r\n        x = x.dropna()\r\n        y = y.loc[x.index, :].reset_index(drop=True)\r\n        x = x.reset_index(drop=True)\r\n        return x, y\r\n\r\n\r\nclass 
feature_selection:\r\n    '''feature selection combining feature importance ranking and GA optimization based on random forest model'''\r\n\r\n    @classmethod\r\n    def select(cls, x, y):\r\n        rf_imp_features = cls.rf_imp_features(x, y)\r\n        ga_features = cls.GA_features(x, y)\r\n        features = set(rf_imp_features) | set(ga_features)\r\n\r\n        return list(features)\r\n\r\n    @classmethod\r\n    def rf_imp_features(cls, x, y, top_perc=0.05):\r\n        '''select top features based on feature importance ranking among all the features'''\r\n        feature_imp = cls.rf_importance_selection(x, y)\r\n        perc_threshold = np.percentile(feature_imp['avg_importance'], int((1 - top_perc) * 100))\r\n        features = list(feature_imp.loc[feature_imp['avg_importance'] >= perc_threshold, 'feature'])\r\n\r\n        return features\r\n\r\n    @staticmethod\r\n    def rf_importance_selection(x, y, iter_time=3):\r\n        feature_imp = pd.DataFrame(np.zeros((x.shape[1], iter_time + 2)))\r\n        feature_imp.columns = ['feature'] + ['importance_{}'.format(i) for i in range(1, iter_time + 1)] + [\r\n            'avg_importance']\r\n        for col in feature_imp.columns:\r\n            feature_imp[col] = list(x.columns)\r\n\r\n        for i in range(1, iter_time + 1):\r\n            col = 'importance_{}'.format(i)\r\n            rf = RandomForestClassifier(n_estimators=10, max_depth=8)\r\n            rf.fit(x, y)\r\n            feature_imp_dict = dict(zip(x.columns, rf.feature_importances_))\r\n            feature_imp[col] = feature_imp[col].replace(feature_imp_dict)\r\n\r\n        feature_imp['avg_importance'] = feature_imp.iloc[:, 1:-1].mean(axis=1)\r\n        return feature_imp\r\n\r\n    @staticmethod\r\n    def GA_features(x, y):\r\n        rf = RandomForestClassifier(max_depth=8, n_estimators=10)\r\n        selector = GeneticSelectionCV(\r\n            rf,\r\n            cv=TimeSeriesSplit(n_splits=4),\r\n            verbose=1,\r\n            
scoring=\"accuracy\",\r\n            max_features=80,\r\n            n_population=200,\r\n            crossover_proba=0.5,\r\n            mutation_proba=0.2,\r\n            n_generations=100,\r\n            crossover_independent_proba=0.5,\r\n            mutation_independent_proba=0.05,\r\n            tournament_size=3,\r\n            n_gen_no_change=5,\r\n            caching=True,\r\n            n_jobs=-1\r\n        )\r\n        selector = selector.fit(x, y)\r\n        features = x.columns[selector.support_]\r\n\r\n        return features\r\n\r\n\r\nclass model:\r\n    lgbm_paramgrid = {\r\n        'learning_rate': np.arange(0.0005, 0.0015, 0.0001),\r\n        'n_estimators': range(800, 2000, 200),\r\n        'max_depth': [3, 4],\r\n        'colsample_bytree': np.arange(0.2, 0.5, 0.1),\r\n        'reg_alpha': [1],\r\n        'reg_lambda': [1]\r\n    }\r\n\r\n    @staticmethod\r\n    def random_forest(x, y):\r\n        rf = RandomForestClassifier(n_estimators=200, max_depth=8)\r\n        rf.fit(x, y)\r\n        return rf\r\n\r\n    @classmethod\r\n    def lightgbm(cls, x, y):\r\n        keys, vals = list(zip(*cls.lgbm_paramgrid.items()))\r\n        products = list(product(*vals))\r\n        param_comb = [dict(zip(keys, i)) for i in products]\r\n\r\n        if len(param_comb) > 1000:\r\n            best_param = cls.GA_tune_lgbm(x, y)\r\n        else:\r\n            best_param = cls.GS_tune_lgbm(x, y)\r\n\r\n        lightgbm = LGBMClassifier(**best_param)\r\n        lightgbm.fit(x, y)\r\n\r\n        return lightgbm\r\n\r\n    @classmethod\r\n    def GA_tune_lgbm(cls, x, y):\r\n        tuner = EvolutionaryAlgorithmSearchCV(\r\n            estimator=LGBMClassifier(),\r\n            params=cls.lgbm_paramgrid,\r\n            scoring=\"accuracy\",\r\n            cv=TimeSeriesSplit(n_splits=4),\r\n            verbose=1,\r\n            population_size=50,\r\n            gene_mutation_prob=0.2,\r\n            gene_crossover_prob=0.5,\r\n            tournament_size=3,\r\n     
       generations_number=20,\r\n        )\r\n        tuner.fit(x, y)\r\n        return tuner.best_params_\r\n\r\n    @classmethod\r\n    def GS_tune_lgbm(cls, x, y):\r\n        tuner = GridSearchCV(\r\n            estimator=LGBMClassifier(),\r\n            param_grid=cls.lgbm_paramgrid,\r\n            scoring=\"accuracy\",\r\n            cv=TimeSeriesSplit(n_splits=4),\r\n            verbose=1,\r\n            n_jobs=-1,\r\n        )\r\n        tuner.fit(x, y)\r\n        return tuner.best_params_\r\n\r\n\r\nclass feature:\r\n    @staticmethod\r\n    def save(features, correlation_remove):\r\n        final = {\r\n            'keep_features': features,\r\n            'correlation_remove': correlation_remove\r\n        }\r\n\r\n        with open('features.txt', 'w') as f:\r\n            f.write(json.dumps(final))\r\n\r\n    @staticmethod\r\n    def load():\r\n        with open('features.txt', 'r') as f:\r\n            features = f.read()\r\n            features = json.loads(features)\r\n\r\n        return features\r\n\r\n\r\ndef train_model(data, target_label):\r\n    data = data.copy()\r\n    data = preprocessing(data)\r\n    check_null(data)\r\n    data = fill_null(data)\r\n    x, y = x_y_split(data)\r\n    x = feature_eng.basic_features(x)\r\n    x = correlation_filter.filter(x)\r\n    x = feature_eng.lag_rolling_features(x)\r\n    x, y = feature_eng.remove_na(x, y)\r\n    y = y[target_label]\r\n    features = feature_selection.select(x, y)\r\n    feature.save(features, correlation_filter.remove_cols)\r\n    lightgbm = model.lightgbm(x[features], y)\r\n    rf = model.random_forest(x[features], y)\r\n    joblib.dump(rf, 'rf.joblib')\r\n    joblib.dump(lightgbm, 'lgbm.joblib')\r\n\r\n\r\ndef predict(data, target_label):\r\n    '''returns both the prediction and the target_label'''\r\n    features = feature.load()['keep_features']\r\n    correlation_remove = feature.load()['correlation_remove']\r\n    data = data.copy()\r\n    data = preprocessing(data)\r\n    data = 
fill_null(data)\r\n    x, y = x_y_split(data)\r\n    x = feature_eng.basic_features(x)\r\n    x = x.drop(correlation_remove, axis=1)\r\n    x = feature_eng.lag_rolling_features(x)\r\n    x, y = feature_eng.remove_na(x, y)\r\n    y = y[target_label]\r\n    x = x[features]\r\n    lgbm = joblib.load('lgbm.joblib')\r\n    rf = joblib.load('rf.joblib')\r\n    lgbm_predict = lgbm.predict_proba(x)\r\n    rf_predict = rf.predict_proba(x)\r\n    final_predict = (lgbm_predict + rf_predict) / 2\r\n    final_predict = np.argmax(final_predict, axis=1)\r\n\r\n    return final_predict, y\r\n\r\n\r\nif __name__ == '__main__':\r\n    data = pd.read_csv('data.csv')\r\n    target_label = '_5s_side'\r\n    train_model(data, target_label)\r\n    pred, true_val = predict(data, target_label)\r\n"
  }
]