Repository: SJTU-DMTai/StockMixer
Branch: master
Commit: cce13598afd3
Files: 23
Total size: 158.1 MB

Directory structure:
gitextract_q0lqk85k/
├── .idea/
│   ├── .gitignore
│   ├── StockMixer.iml
│   ├── inspectionProfiles/
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── dataset/
│   ├── NASDAQ/
│   │   ├── eod_data.pkl
│   │   ├── gt_data.pkl
│   │   ├── mask_data.pkl
│   │   └── price_data.pkl
│   ├── NYSE/
│   │   ├── eod_data.pkl
│   │   ├── gt_data.pkl
│   │   ├── mask_data.pkl
│   │   └── price_data.pkl
│   └── SP500/
│       ├── SP500.npy
│       ├── baseline_data_sp500.npy
│       └── sp500_ticker.csv
├── requirements.txt
└── src/
    ├── evaluator.py
    ├── load_data.py
    ├── model.py
    └── train.py


================================================
FILE CONTENTS
================================================

================================================
FILE: .idea/.gitignore
================================================
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

================================================
FILE: .idea/StockMixer.iml
================================================

================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================

================================================
FILE: .idea/misc.xml
================================================

================================================
FILE: .idea/modules.xml
================================================

================================================
FILE: .idea/vcs.xml
================================================

================================================
FILE: README.md
================================================
# StockMixer
Official code implementation and supplementary material of the AAAI 2024 paper "**StockMixer: A Simple yet Strong MLP-based Architecture for Stock Price Forecasting**".

This work proposes a lightweight and effective MLP-based architecture for stock price forecasting named StockMixer. It consists of indicator mixing, temporal mixing and stock mixing to capture complex correlations in the stock data. The end-to-end training flow of StockMixer is presented as follows:

## Environment
- Python 3.7
- torch~=1.10.1
- numpy~=1.21.5
- PyYAML, pandas, tqdm, matplotlib

## Dataset and Preprocessing
The original datasets (NASDAQ, NYSE and S&P500) are available from:

NASDAQ/NYSE: [Temporal Relational Ranking for Stock Prediction](https://github.com/fulifeng/Temporal_Relational_Stock_Ranking)

S&P500: [Efficient Integration of Multi-Order Dynamics and Internal Dynamics in Stock Movement Prediction](https://github.com/thanhtrunghuynh93/estimate)

To improve file reading speed, we process the raw data into the corresponding .pkl or .npy files. The preprocessed datasets are provided in the `dataset` folder. Because StockMixer does not require prior knowledge such as graphs or hypergraphs, our preprocessed datasets do not include them; you can find them in the original datasets.
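
The sketch below illustrates what the preprocessed labels contain. It mirrors the logic of `load_EOD_data` (in `src/load_data.py`) and `get_batch` (in `src/train.py`):

- A price equal to the sentinel `-1234` marks a missing day.
- The label is the `steps`-day return ratio `(p[t] - p[t-steps]) / p[t-steps]`.
- A stock is kept in a sample only if it has data on every day of the lookback window plus the prediction horizon.

The helper names `return_ratio_labels` and `window_mask` are hypothetical and not part of this repository. Treat this as a minimal sketch under those assumptions, not as the exact script used to build the `.pkl` files.

```python
import numpy as np

MISSING = -1234.0  # sentinel for a missing price, as checked in load_EOD_data


def return_ratio_labels(close, steps=1):
    """Hypothetical helper: per-day return-ratio labels and a validity mask
    for one ticker's closing-price series."""
    close = np.asarray(close, dtype=np.float64)
    missing = np.abs(close - MISSING) < 1e-8
    mask = np.where(missing, 0.0, 1.0)
    gt = np.zeros_like(close)
    for t in range(steps, len(close)):
        # A label exists only if both today's and the reference day's prices exist.
        if not missing[t] and not missing[t - steps]:
            gt[t] = (close[t] - close[t - steps]) / close[t - steps]
    return gt, mask


def window_mask(mask, offset, lookback, steps=1):
    """Hypothetical helper: for a [stocks, days] mask, return 1.0 for stocks
    with data on every day of [offset, offset + lookback + steps), else 0.0
    (the same min-over-window rule used in get_batch)."""
    return mask[:, offset:offset + lookback + steps].min(axis=1)


# Example: a single missing day (index 2) in a five-day series.
gt, mask = return_ratio_labels([10.0, 11.0, -1234.0, 12.0, 12.0])
print(gt.tolist())    # [0.0, 0.1, 0.0, 0.0, 0.0]
print(mask.tolist())  # [1.0, 1.0, 0.0, 1.0, 1.0]
```

In the example, one missing day removes two labels: the missing day itself, and the following day, whose reference price is missing.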

## Running the Code
```
# edit configurations in train.py
python src/train.py
```

================================================
FILE: dataset/NASDAQ/eod_data.pkl
================================================
[File too large to display: 24.4 MB]

================================================
FILE: dataset/NYSE/eod_data.pkl
================================================

================================================
FILE: dataset/NYSE/gt_data.pkl
================================================

================================================
FILE: dataset/NYSE/mask_data.pkl
================================================

================================================
FILE: dataset/NYSE/price_data.pkl
================================================

================================================
FILE: dataset/SP500/SP500.npy
================================================
[File too large to display: 45.7 MB]

================================================
FILE: dataset/SP500/baseline_data_sp500.npy
================================================
[File too large to display: 88.0 MB]

================================================
FILE: dataset/SP500/sp500_ticker.csv
================================================
Symbol,Name,Sector
MMM,3M,Industrials
AOS,A. O. Smith,Industrials
ABT,Abbott Laboratories,Health Care
ABBV,AbbVie,Health Care
ABMD,Abiomed,Health Care
ACN,Accenture,Information Technology
ATVI,Activision Blizzard,Communication Services
ADM,ADM,Consumer Staples
ADBE,Adobe,Information Technology
AAP,Advance Auto Parts,Consumer Discretionary
AMD,Advanced Micro Devices,Information Technology
AES,AES Corp,Utilities
AFL,Aflac,Financials
A,Agilent Technologies,Health Care
APD,Air Products & Chemicals,Materials
AKAM,Akamai Technologies,Information Technology
ALK,Alaska Air Group,Industrials
ALB,Albemarle Corporation,Materials
ARE,Alexandria Real Estate Equities,Real Estate
ALGN,Align Technology,Health Care
ALLE,Allegion,Industrials
LNT,Alliant Energy,Utilities
ALL,Allstate Corp,Financials
GOOGL,Alphabet (Class A),Communication Services
GOOG,Alphabet (Class C),Communication Services
MO,Altria Group,Consumer Staples
AMZN,Amazon,Consumer Discretionary
AMCR,Amcor,Materials
AEE,Ameren Corp,Utilities
AAL,American Airlines Group,Industrials
AEP,American Electric Power,Utilities
AXP,American Express,Financials
AIG,American International Group,Financials
AMT,American Tower,Real Estate
AWK,American Water Works,Utilities
AMP,Ameriprise Financial,Financials
ABC,AmerisourceBergen,Health Care
AME,Ametek,Industrials
AMGN,Amgen,Health Care
APH,Amphenol,Information Technology
ADI,Analog Devices,Information Technology
ANSS,Ansys,Information Technology
ANTM,Anthem,Health Care
AON,Aon,Financials
APA,APA Corporation,Energy
AAPL,Apple,Information Technology
AMAT,Applied Materials,Information Technology
APTV,Aptiv,Consumer Discretionary
ANET,Arista Networks,Information Technology
AJG,Arthur J. Gallagher & Co.,Financials
AIZ,Assurant,Financials
T,AT&T,Communication Services
ATO,Atmos Energy,Utilities
ADSK,Autodesk,Information Technology
ADP,Automatic Data Processing,Information Technology
AZO,AutoZone,Consumer Discretionary
AVB,AvalonBay Communities,Real Estate
AVY,Avery Dennison,Materials
BKR,Baker Hughes,Energy
BLL,Ball Corp,Materials
BAC,Bank of America,Financials
BBWI,Bath & Body Works Inc.,Consumer Discretionary
BAX,Baxter International,Health Care
BDX,Becton Dickinson,Health Care
BRK-B,Berkshire Hathaway,Financials
BBY,Best Buy,Consumer Discretionary
BIO,Bio-Rad Laboratories,Health Care
TECH,Bio-Techne,Health Care
BIIB,Biogen,Health Care
BLK,BlackRock,Financials
BK,BNY Mellon,Financials
BA,Boeing,Industrials
BKNG,Booking Holdings,Consumer Discretionary
BWA,BorgWarner,Consumer Discretionary
BXP,Boston Properties,Real Estate
BSX,Boston Scientific,Health Care
BMY,Bristol Myers Squibb,Health Care
AVGO,Broadcom,Information Technology
BR,Broadridge Financial Solutions,Information Technology
BRO,Brown & Brown,Financials
BF-B,Brown–Forman,Consumer Staples
CHRW,C. H. Robinson,Industrials
CDNS,Cadence Design Systems,Information Technology
CZR,Caesars Entertainment,Consumer Discretionary
CPB,Campbell Soup,Consumer Staples
COF,Capital One Financial,Financials
CAH,Cardinal Health,Health Care
KMX,CarMax,Consumer Discretionary
CCL,Carnival Corporation,Consumer Discretionary
CARR,Carrier Global,Industrials
CTLT,Catalent,Health Care
CAT,Caterpillar,Industrials
CBOE,Cboe Global Markets,Financials
CBRE,CBRE,Real Estate
CDW,CDW,Information Technology
CE,Celanese,Materials
CNC,Centene Corporation,Health Care
CNP,CenterPoint Energy,Utilities
CDAY,Ceridian,Information Technology
CERN,Cerner,Health Care
CF,CF Industries,Materials
CRL,Charles River Laboratories,Health Care
SCHW,Charles Schwab Corporation,Financials
CHTR,Charter Communications,Communication Services
CVX,Chevron Corporation,Energy
CMG,Chipotle Mexican Grill,Consumer Discretionary
CB,Chubb,Financials
CHD,Church & Dwight,Consumer Staples
CI,Cigna,Health Care
CINF,Cincinnati Financial,Financials
CTAS,Cintas Corporation,Industrials
CSCO,Cisco Systems,Information Technology
C,Citigroup,Financials
CFG,Citizens Financial Group,Financials
CTXS,Citrix Systems,Information Technology
CLX,Clorox,Consumer Staples
CME,CME Group,Financials
CMS,CMS Energy,Utilities
KO,Coca-Cola Company,Consumer Staples
CTSH,Cognizant Technology Solutions,Information Technology
CL,Colgate-Palmolive,Consumer Staples
CMCSA,Comcast,Communication Services
CMA,Comerica,Financials
CAG,Conagra Brands,Consumer Staples
COP,ConocoPhillips,Energy
ED,Consolidated Edison,Utilities
STZ,Constellation Brands,Consumer Staples
CPRT,Copart,Industrials
GLW,Corning,Information Technology
CTVA,Corteva,Materials
COST,Costco,Consumer Staples
CTRA,Coterra,Energy
CCI,Crown Castle,Real Estate
CSX,CSX,Industrials
CMI,Cummins,Industrials
CVS,CVS Health,Health Care
DHI,D. R. Horton,Consumer Discretionary
DHR,Danaher Corporation,Health Care
DRI,Darden Restaurants,Consumer Discretionary
DVA,DaVita,Health Care
DE,Deere & Co.,Industrials
DAL,Delta Air Lines,Industrials
XRAY,Dentsply Sirona,Health Care
DVN,Devon Energy,Energy
DXCM,DexCom,Health Care
FANG,Diamondback Energy,Energy
DLR,Digital Realty Trust,Real Estate
DFS,Discover Financial Services,Financials
DISCA,Discovery (Series A),Communication Services
DISCK,Discovery (Series C),Communication Services
DISH,Dish Network,Communication Services
DG,Dollar General,Consumer Discretionary
DLTR,Dollar Tree,Consumer Discretionary
D,Dominion Energy,Utilities
DPZ,Domino's Pizza,Consumer Discretionary
DOV,Dover Corporation,Industrials
DOW,Dow,Materials
DTE,DTE Energy,Utilities
DUK,Duke Energy,Utilities
DRE,Duke Realty Corp,Real Estate
DD,DuPont,Materials
DXC,DXC Technology,Information Technology
EMN,Eastman Chemical,Materials
ETN,Eaton Corporation,Industrials
EBAY,eBay,Consumer Discretionary
ECL,Ecolab,Materials
EIX,Edison International,Utilities
EW,Edwards Lifesciences,Health Care
EA,Electronic Arts,Communication Services
LLY,Eli Lilly & Co,Health Care
EMR,Emerson Electric Company,Industrials
ENPH,Enphase Energy,Information Technology
ETR,Entergy,Utilities
EOG,EOG Resources,Energy
EFX,Equifax,Industrials
EQIX,Equinix,Real Estate
EQR,Equity Residential,Real Estate
ESS,Essex Property Trust,Real Estate
EL,Estée Lauder Companies,Consumer Staples
ETSY,Etsy,Consumer Discretionary
RE,Everest Re,Financials
EVRG,Evergy,Utilities
ES,Eversource Energy,Utilities
EXC,Exelon,Utilities
EXPE,Expedia Group,Consumer Discretionary
EXPD,Expeditors,Industrials
EXR,Extra Space Storage,Real Estate
XOM,ExxonMobil,Energy
FFIV,F5 Networks,Information Technology
FB,Facebook,Communication Services
FAST,Fastenal,Industrials
FRT,Federal Realty Investment Trust,Real Estate
FDX,FedEx,Industrials
FIS,Fidelity National Information Services,Information Technology
FITB,Fifth Third Bancorp,Financials
FRC,First Republic Bank,Financials
FE,FirstEnergy,Utilities
FISV,Fiserv,Information Technology
FLT,Fleetcor,Information Technology
FMC,FMC Corporation,Materials
F,Ford,Consumer Discretionary
FTNT,Fortinet,Information Technology
FTV,Fortive,Industrials
FBHS,Fortune Brands Home & Security,Industrials
FOXA,Fox Corporation (Class A),Communication Services
FOX,Fox Corporation (Class B),Communication Services
BEN,Franklin Resources,Financials
FCX,Freeport-McMoRan,Materials
GPS,Gap,Consumer Discretionary
GRMN,Garmin,Consumer Discretionary
IT,Gartner,Information Technology
GNRC,Generac Holdings,Industrials
GD,General Dynamics,Industrials
GE,General Electric,Industrials
GIS,General Mills,Consumer Staples
GM,General Motors,Consumer Discretionary
GPC,Genuine Parts,Consumer Discretionary
GILD,Gilead Sciences,Health Care
GPN,Global Payments,Information Technology
GL,Globe Life,Financials
GS,Goldman Sachs,Financials
HAL,Halliburton,Energy
HBI,Hanesbrands,Consumer Discretionary
HAS,Hasbro,Consumer Discretionary
HCA,HCA Healthcare,Health Care
PEAK,Healthpeak Properties,Real Estate
HSIC,Henry Schein,Health Care
HES,Hess Corporation,Energy
HPE,Hewlett Packard Enterprise,Information Technology
HLT,Hilton Worldwide,Consumer Discretionary
HOLX,Hologic,Health Care
HD,Home Depot,Consumer Discretionary
HON,Honeywell,Industrials
HRL,Hormel,Consumer Staples
HST,Host Hotels & Resorts,Real Estate
HWM,Howmet Aerospace,Industrials
HPQ,HP,Information Technology
HUM,Humana,Health Care
HBAN,Huntington Bancshares,Financials
HII,Huntington Ingalls Industries,Industrials
IBM,IBM,Information Technology
IEX,IDEX Corporation,Industrials
IDXX,Idexx Laboratories,Health Care
INFO,IHS Markit,Industrials
ITW,Illinois Tool Works,Industrials
ILMN,Illumina,Health Care
INCY,Incyte,Health Care
IR,Ingersoll Rand,Industrials
INTC,Intel,Information Technology
ICE,Intercontinental Exchange,Financials
IFF,International Flavors & Fragrances,Materials
IP,International Paper,Materials
IPG,Interpublic Group,Communication Services
INTU,Intuit,Information Technology
ISRG,Intuitive Surgical,Health Care
IVZ,Invesco,Financials
IPGP,IPG Photonics,Information Technology
IQV,IQVIA,Health Care
IRM,Iron Mountain,Real Estate
JBHT,J. B. Hunt,Industrials
JKHY,Jack Henry & Associates,Information Technology
J,Jacobs Engineering Group,Industrials
SJM,JM Smucker,Consumer Staples
JNJ,Johnson & Johnson,Health Care
JCI,Johnson Controls,Industrials
JPM,JPMorgan Chase,Financials
JNPR,Juniper Networks,Information Technology
KSU,Kansas City Southern,Industrials
K,Kellogg's,Consumer Staples
KEY,KeyCorp,Financials
KEYS,Keysight Technologies,Information Technology
KMB,Kimberly-Clark,Consumer Staples
KIM,Kimco Realty,Real Estate
KMI,Kinder Morgan,Energy
KLAC,KLA Corporation,Information Technology
KHC,Kraft Heinz,Consumer Staples
KR,Kroger,Consumer Staples
LHX,L3Harris Technologies,Industrials
LH,LabCorp,Health Care
LRCX,Lam Research,Information Technology
LW,Lamb Weston,Consumer Staples
LVS,Las Vegas Sands,Consumer Discretionary
LEG,Leggett & Platt,Consumer Discretionary
LDOS,Leidos,Industrials
LEN,Lennar,Consumer Discretionary
LNC,Lincoln National,Financials
LIN,Linde,Materials
LYV,Live Nation Entertainment,Communication Services
LKQ,LKQ Corporation,Consumer Discretionary
LMT,Lockheed Martin,Industrials
L,Loews Corporation,Financials
LOW,Lowe's,Consumer Discretionary
LUMN,Lumen Technologies,Communication Services
LYB,LyondellBasell,Materials
MTB,M&T Bank,Financials
MRO,Marathon Oil,Energy
MPC,Marathon Petroleum,Energy
MKTX,MarketAxess,Financials
MAR,Marriott International,Consumer Discretionary
MMC,Marsh & McLennan,Financials
MLM,Martin Marietta Materials,Materials
MAS,Masco,Industrials
MA,Mastercard,Information Technology
MTCH,Match Group,Communication Services
MKC,McCormick & Company,Consumer Staples
MCD,McDonald's,Consumer Discretionary
MCK,McKesson Corporation,Health Care
MDT,Medtronic,Health Care
MRK,Merck & Co.,Health Care
MET,MetLife,Financials
MTD,Mettler Toledo,Health Care
MGM,MGM Resorts International,Consumer Discretionary
MCHP,Microchip Technology,Information Technology
MU,Micron Technology,Information Technology
MSFT,Microsoft,Information Technology
MAA,Mid-America Apartments,Real Estate
MRNA,Moderna,Health Care
MHK,Mohawk Industries,Consumer Discretionary
TAP,Molson Coors Beverage Company,Consumer Staples
MDLZ,Mondelez International,Consumer Staples
MPWR,Monolithic Power Systems,Information Technology
MNST,Monster Beverage,Consumer Staples
MCO,Moody's Corporation,Financials
MS,Morgan Stanley,Financials
MSI,Motorola Solutions,Information Technology
MSCI,MSCI,Financials
NDAQ,Nasdaq,Financials
NTAP,NetApp,Information Technology
NFLX,Netflix,Communication Services
NWL,Newell Brands,Consumer Discretionary
NEM,Newmont,Materials
NWSA,News Corp (Class A),Communication Services
NWS,News Corp (Class B),Communication Services
NEE,NextEra Energy,Utilities
NLSN,Nielsen Holdings,Industrials
NKE,Nike,Consumer Discretionary
NI,NiSource,Utilities
NSC,Norfolk Southern,Industrials
NTRS,Northern Trust,Financials
NOC,Northrop Grumman,Industrials
NLOK,NortonLifeLock,Information Technology
NCLH,Norwegian Cruise Line Holdings,Consumer Discretionary
NRG,NRG Energy,Utilities
NUE,Nucor,Materials
NVDA,Nvidia,Information Technology
NVR,NVR,Consumer Discretionary
NXPI,NXP,Information Technology
ORLY,O'Reilly Automotive,Consumer Discretionary
OXY,Occidental Petroleum,Energy
ODFL,Old Dominion Freight Line,Industrials
OMC,Omnicom Group,Communication Services
OKE,Oneok,Energy
ORCL,Oracle,Information Technology
OGN,Organon & Co.,Health Care
OTIS,Otis Worldwide,Industrials
PCAR,Paccar,Industrials
PKG,Packaging Corporation of America,Materials
PH,Parker-Hannifin,Industrials
PAYX,Paychex,Information Technology
PAYC,Paycom,Information Technology
PYPL,PayPal,Information Technology
PENN,Penn National Gaming,Consumer Discretionary
PNR,Pentair,Industrials
PBCT,People's United Financial,Financials
PEP,PepsiCo,Consumer Staples
PKI,PerkinElmer,Health Care
PFE,Pfizer,Health Care
PM,Philip Morris International,Consumer Staples
PSX,Phillips 66,Energy
PNW,Pinnacle West Capital,Utilities
PXD,Pioneer Natural Resources,Energy
PNC,PNC Financial Services,Financials
POOL,Pool Corporation,Consumer Discretionary
PPG,PPG Industries,Materials
PPL,PPL,Utilities
PFG,Principal Financial Group,Financials
PG,Procter & Gamble,Consumer Staples
PGR,Progressive Corporation,Financials
PLD,Prologis,Real Estate
PRU,Prudential Financial,Financials
PTC,PTC,Information Technology
PEG,Public Service Enterprise Group,Utilities
PSA,Public Storage,Real Estate
PHM,PulteGroup,Consumer Discretionary
PVH,PVH,Consumer Discretionary
QRVO,Qorvo,Information Technology
QCOM,Qualcomm,Information Technology
PWR,Quanta Services,Industrials
DGX,Quest Diagnostics,Health Care
RL,Ralph Lauren Corporation,Consumer Discretionary
RJF,Raymond James Financial,Financials
RTX,Raytheon Technologies,Industrials
O,Realty Income Corporation,Real Estate
REG,Regency Centers,Real Estate
REGN,Regeneron Pharmaceuticals,Health Care
RF,Regions Financial Corporation,Financials
RSG,Republic Services,Industrials
RMD,ResMed,Health Care
RHI,Robert Half International,Industrials
ROK,Rockwell Automation,Industrials
ROL,Rollins,Industrials
ROP,Roper Technologies,Industrials
ROST,Ross Stores,Consumer Discretionary
RCL,Royal Caribbean Group,Consumer Discretionary
SPGI,S&P Global,Financials
CRM,Salesforce,Information Technology
SBAC,SBA Communications,Real Estate
SLB,Schlumberger,Energy
STX,Seagate Technology,Information Technology
SEE,Sealed Air,Materials
SRE,Sempra Energy,Utilities
NOW,ServiceNow,Information Technology
SHW,Sherwin-Williams,Materials
SPG,Simon Property Group,Real Estate
SWKS,Skyworks Solutions,Information Technology
SNA,Snap-on,Industrials
SO,Southern Company,Utilities
LUV,Southwest Airlines,Industrials
SWK,Stanley Black & Decker,Industrials
SBUX,Starbucks,Consumer Discretionary
STT,State Street Corporation,Financials
STE,Steris,Health Care
SYK,Stryker Corporation,Health Care
SIVB,SVB Financial,Financials
SYF,Synchrony Financial,Financials
SNPS,Synopsys,Information Technology
SYY,Sysco,Consumer Staples
TMUS,T-Mobile US,Communication Services
TROW,T. Rowe Price,Financials
TTWO,Take-Two Interactive,Communication Services
TPR,Tapestry,Consumer Discretionary
TGT,Target Corporation,Consumer Discretionary
TEL,TE Connectivity,Information Technology
TDY,Teledyne Technologies,Industrials
TFX,Teleflex,Health Care
TER,Teradyne,Information Technology
TSLA,Tesla,Consumer Discretionary
TXN,Texas Instruments,Information Technology
TXT,Textron,Industrials
COO,The Cooper Companies,Health Care
HIG,The Hartford,Financials
HSY,The Hershey Company,Consumer Staples
MOS,The Mosaic Company,Materials
TRV,The Travelers Companies,Financials
DIS,The Walt Disney Company,Communication Services
TMO,Thermo Fisher Scientific,Health Care
TJX,TJX Companies,Consumer Discretionary
TSCO,Tractor Supply Company,Consumer Discretionary
TT,Trane Technologies,Industrials
TDG,TransDigm Group,Industrials
TRMB,Trimble,Information Technology
TFC,Truist Financial,Financials
TWTR,Twitter,Communication Services
TYL,Tyler Technologies,Information Technology
TSN,Tyson Foods,Consumer Staples
USB,U.S. Bancorp,Financials
UDR,UDR,Real Estate
ULTA,Ulta Beauty,Consumer Discretionary
UAA,Under Armour (Class A),Consumer Discretionary
UA,Under Armour (Class C),Consumer Discretionary
UNP,Union Pacific,Industrials
UAL,United Airlines,Industrials
UPS,United Parcel Service,Industrials
URI,United Rentals,Industrials
UNH,UnitedHealth Group,Health Care
UHS,Universal Health Services,Health Care
VLO,Valero Energy,Energy
VTR,Ventas,Real Estate
VRSN,Verisign,Information Technology
VRSK,Verisk Analytics,Industrials
VZ,Verizon Communications,Communication Services
VRTX,Vertex Pharmaceuticals,Health Care
VFC,VF Corporation,Consumer Discretionary
VIAC,ViacomCBS,Communication Services
VTRS,Viatris,Health Care
V,Visa,Information Technology
VNO,Vornado Realty Trust,Real Estate
VMC,Vulcan Materials,Materials
WRB,W. R. Berkley Corporation,Financials
GWW,W. W. Grainger,Industrials
WAB,Wabtec,Industrials
WBA,Walgreens Boots Alliance,Consumer Staples
WMT,Walmart,Consumer Staples
WM,Waste Management,Industrials
WAT,Waters Corporation,Health Care
WEC,WEC Energy Group,Utilities
WFC,Wells Fargo,Financials
WELL,Welltower,Real Estate
WST,West Pharmaceutical Services,Health Care
WDC,Western Digital,Information Technology
WU,Western Union,Information Technology
WRK,WestRock,Materials
WY,Weyerhaeuser,Real Estate
WHR,Whirlpool Corporation,Consumer Discretionary
WMB,Williams Companies,Energy
WLTW,Willis Towers Watson,Financials
WYNN,Wynn Resorts,Consumer Discretionary
XEL,Xcel Energy,Utilities
XLNX,Xilinx,Information Technology
XYL,Xylem,Industrials
YUM,Yum! Brands,Consumer Discretionary
ZBRA,Zebra Technologies,Information Technology
ZBH,Zimmer Biomet,Health Care
ZION,Zions Bancorp,Financials
ZTS,Zoetis,Health Care

================================================
FILE: requirements.txt
================================================
tqdm==4.64.1
pandas==2.0.3
pyyaml==6.0
numpy==1.22.1
matplotlib==3.5.1

================================================
FILE: src/evaluator.py
================================================
import numpy as np
import pandas as pd


def evaluate(prediction, ground_truth, mask, report=False):
    assert ground_truth.shape == prediction.shape, 'shape mis-match'
    performance = {}
    # mse
    performance['mse'] = np.linalg.norm((prediction - ground_truth) * mask) ** 2 / np.sum(mask)

    # IC
    df_pred = pd.DataFrame(prediction * mask)
    df_gt = pd.DataFrame(ground_truth * mask)

    ic = []
    mrr_top = 0.0
    all_miss_days_top = 0
    bt_long = 1.0
    bt_long5 = 1.0
    bt_long10 = 1.0
    irr = 0.0
    sharpe_li5 = []
    prec_10 = []

    for i in range(prediction.shape[1]):
        # IC
        ic.append(df_pred[i].corr(df_gt[i]))

        rank_gt = np.argsort(ground_truth[:, i])
        gt_top1 = set()
        gt_top5 = set()
        gt_top10 = set()
        for j in range(1, prediction.shape[0] + 1):
            cur_rank = rank_gt[-1 * j]
            if mask[cur_rank][i] < 0.5:
                continue
            if len(gt_top1) < 1:
                gt_top1.add(cur_rank)
            if len(gt_top5) < 5:
                gt_top5.add(cur_rank)
            if len(gt_top10) < 10:
                gt_top10.add(cur_rank)

        rank_pre = np.argsort(prediction[:, i])
        pre_top1 = set()
        pre_top5 = set()
        pre_top10 = set()
        for j in range(1, prediction.shape[0] + 1):
            cur_rank = rank_pre[-1 * j]
            if mask[cur_rank][i] < 0.5:
                continue
            if len(pre_top1) < 1:
                pre_top1.add(cur_rank)
            if len(pre_top5) < 5:
                pre_top5.add(cur_rank)
            if len(pre_top10) < 10:
                pre_top10.add(cur_rank)

        top1_pos_in_gt = 0
        for j in range(1, prediction.shape[0] + 1):
            cur_rank = rank_gt[-1 * j]
            if mask[cur_rank][i] < 0.5:
                continue
            else:
                top1_pos_in_gt += 1
                if cur_rank in pre_top1:
                    break
        if top1_pos_in_gt == 0:
            all_miss_days_top += 1
        else:
            mrr_top += 1.0 / top1_pos_in_gt

        real_ret_rat_top = ground_truth[list(pre_top1)[0]][i]
        bt_long += real_ret_rat_top

        gt_irr = 0.0
        for gt in gt_top10:
            gt_irr += ground_truth[gt][i]

        real_ret_rat_top5 = 0
        for pre in pre_top5:
            real_ret_rat_top5 += ground_truth[pre][i]
        irr += real_ret_rat_top5
        real_ret_rat_top5 /= 5
        bt_long5 += real_ret_rat_top5

        prec = 0.0
        real_ret_rat_top10 = 0
        for pre in pre_top10:
            real_ret_rat_top10 += ground_truth[pre][i]
            prec += (ground_truth[pre][i] >= 0)
        prec_10.append(prec / 10)
        real_ret_rat_top10 /= 10
        bt_long10 += real_ret_rat_top10
        sharpe_li5.append(real_ret_rat_top5)

    performance['IC'] = np.mean(ic)
    performance['RIC'] = np.mean(ic) / np.std(ic)
    sharpe_li5 = np.array(sharpe_li5)
    performance['sharpe5'] = (np.mean(sharpe_li5)/np.std(sharpe_li5))*15.87
    performance['prec_10'] = np.mean(prec_10)
    return performance

================================================
FILE: src/load_data.py
================================================
import numpy as np
import os
from tqdm import tqdm


def load_EOD_data(data_path, market_name, tickers, steps=1):
    eod_data = []
    masks = []
    ground_truth = []
    base_price = []
    for index, ticker in enumerate(tqdm(tickers)):
        single_EOD = np.genfromtxt(
            os.path.join(data_path, market_name + '_' + ticker + '_1.csv'),
            dtype=np.float32,
            delimiter=',', skip_header=False
        )
        if market_name == 'NASDAQ':
            # remove the last day since lots of missing data
            single_EOD = single_EOD[:-1, :]
        if index == 0:
            print('single EOD data shape:', single_EOD.shape)
            eod_data = np.zeros([len(tickers), single_EOD.shape[0], single_EOD.shape[1] - 1], dtype=np.float32)
            masks = np.ones([len(tickers), single_EOD.shape[0]], dtype=np.float32)
            ground_truth = np.zeros([len(tickers), single_EOD.shape[0]], dtype=np.float32)
            base_price = np.zeros([len(tickers), single_EOD.shape[0]], dtype=np.float32)
        for row in range(single_EOD.shape[0]):
            if abs(single_EOD[row][-1] + 1234) < 1e-8:
                masks[index][row] = 0.0
            elif row > steps - 1 and abs(single_EOD[row - steps][-1] + 1234) \
                    > 1e-8:
                ground_truth[index][row] = \
                    (single_EOD[row][-1] - single_EOD[row - steps][-1]) / \
                    single_EOD[row - steps][-1]
            for col in range(single_EOD.shape[1]):
                if abs(single_EOD[row][col] + 1234) < 1e-8:
                    single_EOD[row][col] = 1.1
        eod_data[index, :, :] = single_EOD[:, 1:]
        base_price[index, :] = single_EOD[:, -1]
    return eod_data, masks, ground_truth, base_price


def load_graph_relation_data(relation_file, lap=False):
    relation_encoding = np.load(relation_file)
    print('relation encoding shape:', relation_encoding.shape)
    rel_shape = [relation_encoding.shape[0], relation_encoding.shape[1]]
    mask_flags = np.equal(np.zeros(rel_shape, dtype=int), np.sum(relation_encoding, axis=2))
    ajacent = np.where(mask_flags, np.zeros(rel_shape, dtype=float), np.ones(rel_shape, dtype=float))
    degree = np.sum(ajacent, axis=0)
    for i in range(len(degree)):
        degree[i] = 1.0 / degree[i]
    np.sqrt(degree, degree)
    deg_neg_half_power = np.diag(degree)
    if lap:
        return np.identity(ajacent.shape[0], dtype=float) - np.dot(
            np.dot(deg_neg_half_power, ajacent), deg_neg_half_power)
    else:
        return np.dot(np.dot(deg_neg_half_power, ajacent), deg_neg_half_power)


def load_relation_data(relation_file):
    relation_encoding = np.load(relation_file)
    rel_shape = [relation_encoding.shape[0], relation_encoding.shape[1]]
    mask_flags = np.equal(np.zeros(rel_shape, dtype=int), np.sum(relation_encoding, axis=2))
    mask = np.where(mask_flags, np.ones(rel_shape) * -1e9, np.zeros(rel_shape))
    return relation_encoding, mask


def build_SFM_data(data_path, market_name, tickers):
    eod_data = []
    for index, ticker in enumerate(tickers):
        single_EOD = np.genfromtxt(
            os.path.join(data_path, market_name + '_' + ticker + '_1.csv'),
            dtype=np.float32, delimiter=',', skip_header=False
        )
        if index == 0:
            print('single EOD data shape:', single_EOD.shape)
            eod_data = np.zeros([len(tickers), single_EOD.shape[0]], dtype=np.float32)

        for row in range(single_EOD.shape[0]):
            if abs(single_EOD[row][-1] + 1234) < 1e-8:
                if row < 3:
                    for i in range(row + 1, single_EOD.shape[0]):
                        if abs(single_EOD[i][-1] + 1234) > 1e-8:
                            eod_data[index][row] = single_EOD[i][-1]
                            break
                else:
                    eod_data[index][row] = np.sum(
                        eod_data[index, row - 3:row]) / 3
            else:
                eod_data[index][row] = single_EOD[row][-1]
    np.save(market_name + '_sfm_data', eod_data)

================================================
FILE: src/model.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

acv = nn.GELU()


def get_loss(prediction, ground_truth, base_price, mask, batch_size, alpha):
    device = prediction.device
    all_one = torch.ones(batch_size, 1, dtype=torch.float32).to(device)
    return_ratio = torch.div(torch.sub(prediction, base_price), base_price)
    reg_loss = F.mse_loss(return_ratio * mask, ground_truth * mask)
    pre_pw_dif = torch.sub(
        return_ratio @ all_one.t(),
        all_one @ return_ratio.t()
    )
    gt_pw_dif = torch.sub(
        all_one @ ground_truth.t(),
        ground_truth @ all_one.t()
    )
    mask_pw = mask @ mask.t()
    rank_loss = torch.mean(
        F.relu(pre_pw_dif * gt_pw_dif * mask_pw)
    )
    loss = reg_loss + alpha * rank_loss
    return loss, reg_loss, rank_loss, return_ratio


class MixerBlock(nn.Module):
    def __init__(self, mlp_dim, hidden_dim, dropout=0.0):
        super(MixerBlock, self).__init__()
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.dense_1 = nn.Linear(mlp_dim, hidden_dim)
        self.LN = acv
        self.dense_2 = nn.Linear(hidden_dim, mlp_dim)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.LN(x)
        if self.dropout != 0.0:
            x = F.dropout(x, p=self.dropout)
        x = self.dense_2(x)
        if self.dropout != 0.0:
            x = F.dropout(x, p=self.dropout)
        return x


class Mixer2d(nn.Module):
    def __init__(self, time_steps, channels):
        super(Mixer2d, self).__init__()
        self.LN_1 = nn.LayerNorm([time_steps, channels])
        self.LN_2 = nn.LayerNorm([time_steps, channels])
        self.timeMixer = MixerBlock(time_steps, time_steps)
        self.channelMixer = MixerBlock(channels, channels)

    def forward(self, inputs):
        x = self.LN_1(inputs)
        x = x.permute(0, 2, 1)
        x = self.timeMixer(x)
        x = x.permute(0, 2, 1)

        x = self.LN_2(x + inputs)
        y = self.channelMixer(x)
        return x + y


class TriU(nn.Module):
    def __init__(self, time_step):
        super(TriU, self).__init__()
        self.time_step = time_step
        self.triU = nn.ParameterList(
            [
                nn.Linear(i + 1, 1)
                for i in range(time_step)
            ]
        )

    def forward(self, inputs):
        x = self.triU[0](inputs[:, :, 0].unsqueeze(-1))
        for i in range(1, self.time_step):
            x = torch.cat([x, self.triU[i](inputs[:, :, 0:i + 1])], dim=-1)
        return x


class TimeMixerBlock(nn.Module):
    def __init__(self, time_step):
        super(TimeMixerBlock, self).__init__()
        self.time_step = time_step
        self.dense_1 = TriU(time_step)
        self.LN = acv
        self.dense_2 = TriU(time_step)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.LN(x)
        x = self.dense_2(x)
        return x


class MultiScaleTimeMixer(nn.Module):
    def __init__(self, time_step, channel, scale_count=1):
        super(MultiScaleTimeMixer, self).__init__()
        self.time_step = time_step
        self.scale_count = scale_count
        self.mix_layer = nn.ParameterList([nn.Sequential(
            nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=2 ** i, stride=2 ** i),
            TriU(int(time_step / 2 ** i)),
            nn.Hardswish(),
            TriU(int(time_step / 2 ** i))
        ) for i in range(scale_count)])
        self.mix_layer[0] = nn.Sequential(
            nn.LayerNorm([time_step, channel]),
            TriU(int(time_step)),
            nn.Hardswish(),
            TriU(int(time_step))
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)
        y = self.mix_layer[0](x)
        for i in range(1, self.scale_count):
            y = torch.cat((y, self.mix_layer[i](x)), dim=-1)
        return y


class Mixer2dTriU(nn.Module):
    def __init__(self, time_steps, channels):
        super(Mixer2dTriU, self).__init__()
        self.LN_1 = nn.LayerNorm([time_steps, channels])
        self.LN_2 = nn.LayerNorm([time_steps, channels])
        self.timeMixer = TriU(time_steps)
        self.channelMixer = MixerBlock(channels, channels)

    def forward(self, inputs):
        x = self.LN_1(inputs)
        x = x.permute(0, 2, 1)
        x = self.timeMixer(x)
        x = x.permute(0, 2, 1)

        x = self.LN_2(x + inputs)
        y = self.channelMixer(x)
        return x + y


class MultTime2dMixer(nn.Module):
    def __init__(self, time_step, channel, scale_dim=8):
        super(MultTime2dMixer, self).__init__()
        self.mix_layer = Mixer2dTriU(time_step, channel)
        self.scale_mix_layer = Mixer2dTriU(scale_dim, channel)

    def forward(self, inputs, y):
        y = self.scale_mix_layer(y)
        x = self.mix_layer(inputs)
        return torch.cat([inputs, x, y], dim=1)


class NoGraphMixer(nn.Module):
    def __init__(self, stocks, hidden_dim=20):
        super(NoGraphMixer, self).__init__()
        self.dense1 = nn.Linear(stocks, hidden_dim)
        self.activation = nn.Hardswish()
        self.dense2 = nn.Linear(hidden_dim, stocks)
        self.layer_norm_stock = nn.LayerNorm(stocks)

    def forward(self, inputs):
        x = inputs
        x = x.permute(1, 0)
        x = self.layer_norm_stock(x)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        x = x.permute(1, 0)
        return x


class StockMixer(nn.Module):
    def __init__(self, stocks, time_steps, channels, market, scale):
        super(StockMixer, self).__init__()
        scale_dim = 8
        self.mixer = MultTime2dMixer(time_steps, channels, scale_dim=scale_dim)
        self.channel_fc = nn.Linear(channels, 1)
        self.time_fc = nn.Linear(time_steps * 2 + scale_dim, 1)
        self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=2, stride=2)
        self.stock_mixer = NoGraphMixer(stocks, market)
        self.time_fc_ = nn.Linear(time_steps * 2 + scale_dim, 1)

    def forward(self, inputs):
        x = inputs.permute(0, 2, 1)
        x = self.conv(x)
        x = x.permute(0, 2, 1)
        y = self.mixer(inputs, x)
        y = self.channel_fc(y).squeeze(-1)

        z = self.stock_mixer(y)
        y = self.time_fc(y)
        z = self.time_fc_(z)
        return y + z

================================================
FILE: src/train.py
================================================
import random
import numpy as np
import os
import torch as torch
from load_data import load_EOD_data
from evaluator import evaluate
from model import get_loss, StockMixer
import pickle


np.random.seed(123456789)
torch.random.manual_seed(12345678)
device = torch.device("cuda") if torch.cuda.is_available() else 'cpu'

data_path = '../dataset'
market_name = 'NASDAQ'
relation_name = 'wikidata'
stock_num = 1026
lookback_length = 16
epochs = 100
valid_index = 756
test_index = 1008
fea_num = 5
market_num = 20
steps = 1
learning_rate = 0.001
alpha = 0.1
scale_factor = 3
activation = 'GELU'

dataset_path = '../dataset/' + market_name
if market_name == "SP500":
    data = np.load('../dataset/SP500/SP500.npy')
    data = data[:, 915:, :]
    price_data = data[:, :, -1]
    mask_data = np.ones((data.shape[0], data.shape[1]))
    eod_data = data
    gt_data = np.zeros((data.shape[0], data.shape[1]))
    for ticket in range(0, data.shape[0]):
        for row in range(1, data.shape[1]):
            gt_data[ticket][row] = (data[ticket][row][-1] - data[ticket][row - steps][-1]) / \
                                   data[ticket][row - steps][-1]
else:
    with open(os.path.join(dataset_path, "eod_data.pkl"), "rb") as f:
        eod_data = pickle.load(f)
    with open(os.path.join(dataset_path, "mask_data.pkl"), "rb") as f:
        mask_data = pickle.load(f)
    with open(os.path.join(dataset_path, "gt_data.pkl"), "rb") as f:
        gt_data = pickle.load(f)
    with open(os.path.join(dataset_path, "price_data.pkl"), "rb") as f:
        price_data = pickle.load(f)

trade_dates = mask_data.shape[1]
model = StockMixer(
    stocks=stock_num,
    time_steps=lookback_length,
    channels=fea_num,
    market=market_num,
    scale=scale_factor
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
best_valid_loss = np.inf
best_valid_perf = None
best_test_perf = None
batch_offsets = np.arange(start=0, stop=valid_index, dtype=int)


def validate(start_index, end_index):
    with torch.no_grad():
        cur_valid_pred = np.zeros([stock_num, end_index - start_index], dtype=float)
        cur_valid_gt = np.zeros([stock_num, end_index - start_index], dtype=float)
        cur_valid_mask = np.zeros([stock_num, end_index - start_index], dtype=float)
        loss = 0.
        reg_loss = 0.
        rank_loss = 0.
        for cur_offset in range(start_index - lookback_length - steps + 1, end_index - lookback_length - steps + 1):
            data_batch, mask_batch, price_batch, gt_batch = map(
                lambda x: torch.Tensor(x).to(device),
                get_batch(cur_offset)
            )
            prediction = model(data_batch)
            cur_loss, cur_reg_loss, cur_rank_loss, cur_rr = get_loss(prediction, gt_batch, price_batch, mask_batch, stock_num, alpha)
            loss += cur_loss.item()
            reg_loss += cur_reg_loss.item()
            rank_loss += cur_rank_loss.item()
            cur_valid_pred[:, cur_offset - (start_index - lookback_length - steps + 1)] = cur_rr[:, 0].cpu()
            cur_valid_gt[:, cur_offset - (start_index - lookback_length - steps + 1)] = gt_batch[:, 0].cpu()
            cur_valid_mask[:, cur_offset - (start_index - lookback_length - steps + 1)] = mask_batch[:, 0].cpu()
        loss = loss / (end_index - start_index)
        reg_loss = reg_loss / (end_index - start_index)
        rank_loss = rank_loss / (end_index - start_index)
        cur_valid_perf = evaluate(cur_valid_pred, cur_valid_gt, cur_valid_mask)
    return loss, reg_loss, rank_loss, cur_valid_perf


def get_batch(offset=None):
    if offset is None:
        offset = random.randrange(0, valid_index)
    seq_len = lookback_length
    mask_batch = mask_data[:, offset: offset + seq_len + steps]
    mask_batch = np.min(mask_batch, axis=1)
    return (
        eod_data[:, offset:offset + seq_len, :],
        np.expand_dims(mask_batch, axis=1),
        np.expand_dims(price_data[:, offset + seq_len - 1], axis=1),
        np.expand_dims(gt_data[:, offset + seq_len + steps - 1], axis=1))


for epoch in range(epochs):
print("epoch{}##########################################################".format(epoch + 1)) np.random.shuffle(batch_offsets) tra_loss = 0.0 tra_reg_loss = 0.0 tra_rank_loss = 0.0 for j in range(valid_index - lookback_length - steps + 1): data_batch, mask_batch, price_batch, gt_batch = map( lambda x: torch.Tensor(x).to(device), get_batch(batch_offsets[j]) ) optimizer.zero_grad() prediction = model(data_batch) cur_loss, cur_reg_loss, cur_rank_loss, _ = get_loss(prediction, gt_batch, price_batch, mask_batch, stock_num, alpha) cur_loss = cur_loss cur_loss.backward() optimizer.step() tra_loss += cur_loss.item() tra_reg_loss += cur_reg_loss.item() tra_rank_loss += cur_rank_loss.item() tra_loss = tra_loss / (valid_index - lookback_length - steps + 1) tra_reg_loss = tra_reg_loss / (valid_index - lookback_length - steps + 1) tra_rank_loss = tra_rank_loss / (valid_index - lookback_length - steps + 1) print('Train : loss:{:.2e} = {:.2e} + alpha*{:.2e}'.format(tra_loss, tra_reg_loss, tra_rank_loss)) val_loss, val_reg_loss, val_rank_loss, val_perf = validate(valid_index, test_index) print('Valid : loss:{:.2e} = {:.2e} + alpha*{:.2e}'.format(val_loss, val_reg_loss, val_rank_loss)) test_loss, test_reg_loss, test_rank_loss, test_perf = validate(test_index, trade_dates) print('Test: loss:{:.2e} = {:.2e} + alpha*{:.2e}'.format(test_loss, test_reg_loss, test_rank_loss)) if val_loss < best_valid_loss: best_valid_loss = val_loss best_valid_perf = val_perf best_test_perf = test_perf print('Valid performance:\n', 'mse:{:.2e}, IC:{:.2e}, RIC:{:.2e}, prec@10:{:.2e}, SR:{:.2e}'.format(val_perf['mse'], val_perf['IC'], val_perf['RIC'], val_perf['prec_10'], val_perf['sharpe5'])) print('Test performance:\n', 'mse:{:.2e}, IC:{:.2e}, RIC:{:.2e}, prec@10:{:.2e}, SR:{:.2e}'.format(test_perf['mse'], test_perf['IC'], test_perf['RIC'], test_perf['prec_10'], test_perf['sharpe5']), '\n\n')
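The sliding-window index arithmetic shared by `get_batch()` and `validate()` in `src/train.py` determines which trading days a sample reads and which day it is scored on. The sketch below reproduces that arithmetic without loading any data, assuming the default `lookback_length = 16` and `steps = 1` and the NASDAQ split points `valid_index = 756` and `test_index = 1008`. The helpers `window_indices` and `eval_offsets` are hypothetical and are not part of the repository.

```python
# Hypothetical helpers (not part of the StockMixer repository) that mirror the
# index arithmetic of get_batch() and validate() in src/train.py.

def window_indices(offset, lookback_length=16, steps=1):
    """Return (input_days, price_day, target_day) for one sample at `offset`."""
    input_days = range(offset, offset + lookback_length)  # eod_data[:, input_days, :]
    price_day = offset + lookback_length - 1              # price_data column
    target_day = offset + lookback_length + steps - 1     # gt_data column
    return input_days, price_day, target_day


def eval_offsets(start_index, end_index, lookback_length=16, steps=1):
    """Offsets that validate(start_index, end_index) loops over."""
    first = start_index - lookback_length - steps + 1
    return range(first, end_index - lookback_length - steps + 1)


if __name__ == "__main__":
    valid_index, test_index, lookback_length, steps = 756, 1008, 16, 1
    offsets = eval_offsets(valid_index, test_index, lookback_length, steps)
    first_target = window_indices(offsets[0], lookback_length, steps)[2]
    last_target = window_indices(offsets[-1], lookback_length, steps)[2]
    print(len(offsets), first_target, last_target)  # prints: 252 756 1007

    # Largest training offset whose label still precedes the validation split.
    last_train_offset = valid_index - lookback_length - steps
    print(window_indices(last_train_offset, lookback_length, steps)[2])  # prints: 755
```

Under these assumptions, validation scores exactly the 252 target days from 756 to 1007. Offset 739 is the last training offset whose label (day 755) still precedes the validation split.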