Repository: DemonDamon/Listed-company-news-crawl-and-text-analysis Branch: main Commit: d7a20a1f7ee8 Files: 293 Total size: 2.5 MB Directory structure: gitextract_w0u594fz/ ├── .deepsource.toml ├── .gitignore ├── LICENSE ├── README.md ├── README_zn.md ├── backend/ │ ├── .gitignore │ ├── README.md │ ├── README_zn.md │ ├── add_raw_html_column.py │ ├── app/ │ │ ├── __init__.py │ │ ├── agents/ │ │ │ ├── __init__.py │ │ │ ├── data_collector.py │ │ │ ├── data_collector_v2.py │ │ │ ├── debate_agents.py │ │ │ ├── news_analyst.py │ │ │ ├── orchestrator.py │ │ │ ├── quantitative_agent.py │ │ │ └── search_analyst.py │ │ ├── alpha_mining/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── backtest/ │ │ │ │ ├── __init__.py │ │ │ │ └── evaluator.py │ │ │ ├── config.py │ │ │ ├── dsl/ │ │ │ │ ├── __init__.py │ │ │ │ ├── ops.py │ │ │ │ └── vocab.py │ │ │ ├── features/ │ │ │ │ ├── __init__.py │ │ │ │ ├── market.py │ │ │ │ └── sentiment.py │ │ │ ├── model/ │ │ │ │ ├── __init__.py │ │ │ │ ├── alpha_generator.py │ │ │ │ └── trainer.py │ │ │ ├── tools/ │ │ │ │ ├── __init__.py │ │ │ │ └── alpha_mining_tool.py │ │ │ ├── utils.py │ │ │ └── vm/ │ │ │ ├── __init__.py │ │ │ └── factor_vm.py │ │ ├── api/ │ │ │ ├── __init__.py │ │ │ └── v1/ │ │ │ ├── __init__.py │ │ │ ├── agents.py │ │ │ ├── alpha_mining.py │ │ │ ├── analysis.py │ │ │ ├── debug.py │ │ │ ├── knowledge_graph.py │ │ │ ├── llm_config.py │ │ │ ├── news.py │ │ │ ├── news_v2.py │ │ │ ├── stocks.py │ │ │ └── tasks.py │ │ ├── config/ │ │ │ ├── __init__.py │ │ │ └── debate_modes.yaml │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── celery_app.py │ │ │ ├── config.py │ │ │ ├── database.py │ │ │ ├── neo4j_client.py │ │ │ └── redis_client.py │ │ ├── financial/ │ │ │ ├── __init__.py │ │ │ ├── models/ │ │ │ │ ├── __init__.py │ │ │ │ ├── news.py │ │ │ │ └── stock.py │ │ │ ├── providers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── eastmoney/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fetchers/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── news.py │ │ │ │ │ └── provider.py │ │ │ │ ├── nbd/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fetchers/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── news.py │ │ │ │ │ └── provider.py │ │ │ │ ├── netease/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fetchers/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── news.py │ │ │ │ │ └── provider.py │ │ │ │ ├── sina/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fetchers/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── news.py │ │ │ │ │ └── provider.py │ │ │ │ ├── tencent/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fetchers/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── news.py │ │ │ │ │ └── provider.py │ │ │ │ └── yicai/ │ │ │ │ ├── __init__.py │ │ │ │ ├── fetchers/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── news.py │ │ │ │ └── provider.py │ │ │ ├── registry.py │ │ │ └── tools.py │ │ ├── knowledge/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── graph_models.py │ │ │ ├── graph_service.py │ │ │ ├── knowledge_extractor.py │ │ │ └── parallel_search.py │ │ ├── main.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── analysis.py │ │ │ ├── crawl_task.py │ │ │ ├── database.py │ │ │ ├── debate_history.py │ │ │ ├── news.py │ │ │ └── stock.py │ │ ├── scripts/ │ │ │ └── init_stocks.py │ │ ├── services/ │ │ │ ├── __init__.py │ │ │ ├── analysis_service.py │ │ │ ├── embedding_service.py │ │ │ ├── llm_service.py │ │ │ └── stock_data_service.py │ │ ├── storage/ │ │ │ ├── __init__.py │ │ │ └── vector_storage.py │ │ ├── tasks/ │ │ │ ├── __init__.py │ │ │ └── crawl_tasks.py │ │ └── tools/ │ │ ├── __init__.py │ │ ├── bochaai_search.py │ 
│ ├── caijing_crawler.py │ │ ├── crawler_base.py │ │ ├── crawler_enhanced.py │ │ ├── dynamic_crawler_example.py │ │ ├── eastmoney_crawler.py │ │ ├── eeo_crawler.py │ │ ├── interactive_crawler.py │ │ ├── jingji21_crawler.py │ │ ├── jwview_crawler.py │ │ ├── nbd_crawler.py │ │ ├── netease163_crawler.py │ │ ├── search_engine_crawler.py │ │ ├── sina_crawler.py │ │ ├── tencent_crawler.py │ │ ├── text_cleaner.py │ │ └── yicai_crawler.py │ ├── clear_news_data.py │ ├── env.example │ ├── init_db.py │ ├── init_knowledge_graph.py │ ├── requirements.txt │ ├── reset_database.py │ ├── setup_env.sh │ ├── start.sh │ ├── start_celery.sh │ └── tests/ │ ├── __init__.py │ ├── check_milvus_data.py │ ├── check_news_embedding_status.py │ ├── financial/ │ │ ├── __init__.py │ │ ├── test_smoke_openbb_models.py │ │ ├── test_smoke_openbb_provider.py │ │ └── test_smoke_openbb_tools.py │ ├── manual_vectorize.py │ ├── test_alpha_mining/ │ │ ├── __init__.py │ │ ├── test_integration_p2.py │ │ ├── test_smoke_p0.py │ │ └── test_smoke_p1.py │ └── test_smoke_alpha_mining.py ├── deploy/ │ ├── Dockerfile.celery │ ├── celery-entrypoint.sh │ └── docker-compose.dev.yml ├── docs/ │ ├── BochaAI_Web_Search_API_20251222_121535.md │ └── 天眼查MCP服务_20260104_171528.md ├── frontend/ │ ├── .gitignore │ ├── QUICKSTART.md │ ├── README.md │ ├── index.html │ ├── package.json │ ├── postcss.config.js │ ├── src/ │ │ ├── App.tsx │ │ ├── components/ │ │ │ ├── DebateChatRoom.tsx │ │ │ ├── DebateConfig.tsx │ │ │ ├── DebateHistorySidebar.tsx │ │ │ ├── HighlightText.tsx │ │ │ ├── KLineChart.tsx │ │ │ ├── MentionInput.tsx │ │ │ ├── ModelSelector.tsx │ │ │ ├── NewsDetailDrawer.tsx │ │ │ ├── StockSearch.tsx │ │ │ ├── alpha-mining/ │ │ │ │ ├── AgentDemo.tsx │ │ │ │ ├── MetricsDashboard.tsx │ │ │ │ ├── OperatorGrid.tsx │ │ │ │ ├── SentimentCompare.tsx │ │ │ │ ├── TrainingMonitor.tsx │ │ │ │ └── index.ts │ │ │ └── ui/ │ │ │ ├── badge.tsx │ │ │ ├── button.tsx │ │ │ ├── card.tsx │ │ │ ├── dropdown-menu.tsx │ │ │ ├── sheet.tsx │ │ │ └── tabs.tsx │ │ ├── context/ │ │ │ └── NewsToolbarContext.tsx │ │ ├── hooks/ │ │ │ └── useDebounce.ts │ │ ├── index.css │ │ ├── layout/ │ │ │ └── MainLayout.tsx │ │ ├── lib/ │ │ │ ├── api-client.ts │ │ │ └── utils.ts │ │ ├── main.tsx │ │ ├── pages/ │ │ │ ├── AgentMonitorPage.tsx │ │ │ ├── AlphaMiningPage.tsx │ │ │ ├── Dashboard.tsx │ │ │ ├── NewsListPage.tsx │ │ │ ├── StockAnalysisPage.tsx │ │ │ ├── StockSearchPage.tsx │ │ │ └── TaskManagerPage.tsx │ │ ├── store/ │ │ │ ├── useDebateStore.ts │ │ │ ├── useLanguageStore.ts │ │ │ ├── useNewsStore.ts │ │ │ └── useTaskStore.ts │ │ └── types/ │ │ └── api.ts │ ├── tailwind.config.js │ ├── tsconfig.json │ ├── tsconfig.node.json │ └── vite.config.ts ├── legacy_v1/ │ ├── .deepsource.toml │ ├── Chinese_Stop_Words.txt │ ├── Crawler/ │ │ ├── __init__.py │ │ ├── crawler_cnstock.py │ │ ├── crawler_jrj.py │ │ ├── crawler_nbd.py │ │ ├── crawler_sina.py │ │ ├── crawler_stcn.py │ │ └── crawler_tushare.py │ ├── README_OLD.md │ ├── Text_Analysis/ │ │ ├── __init__.py │ │ ├── text_mining.py │ │ └── text_processing.py │ ├── finance_dict.txt │ ├── run_crawler_cnstock.py │ ├── run_crawler_jrj.py │ ├── run_crawler_nbd.py │ ├── run_crawler_sina.py │ ├── run_crawler_stcn.py │ ├── run_crawler_tushare.py │ ├── run_main.py │ └── src/ │ ├── Gon/ │ │ ├── __init__.py │ │ ├── cnstockspyder.py │ │ ├── history_starter_cnstock.py │ │ ├── history_starter_jrj.py │ │ ├── history_starter_nbd.py │ │ ├── history_starter_stock_price.py │ │ ├── ifengspyder.py │ │ ├── jrjspyder.py │ │ ├── kill_realtime_spyder_tasks.py │ │ ├── 
money163spyder.py │ │ ├── nbdspyder.py │ │ ├── realtime_starter_cnstock.py │ │ ├── realtime_starter_jrj.py │ │ ├── realtime_starter_nbd.py │ │ ├── realtime_starter_redis_queue.py │ │ ├── realtime_starter_stock_price.py │ │ ├── sinaspyder.py │ │ ├── spyder.py │ │ └── stockinfospyder.py │ ├── Hisoka/ │ │ └── classifier.py │ ├── Killua/ │ │ ├── __init__.py │ │ ├── buildstocknewsdb.py │ │ ├── deduplication.py │ │ └── denull.py │ ├── Kite/ │ │ ├── __init__.py │ │ ├── config.py │ │ ├── database.py │ │ ├── log.py │ │ ├── utils.py │ │ └── webserver.py │ ├── Leorio/ │ │ ├── __init__.py │ │ ├── chnstopwords.txt │ │ ├── financedict.txt │ │ ├── tokenization.py │ │ └── topicmodelling.py │ ├── __init__.py │ ├── history_spyder_startup.bat │ ├── main.py │ ├── realtime_spyder_startup.bat │ └── realtime_spyder_stopall.bat ├── reset_all_data.sh └── thirdparty/ ├── DISC-FinLLM.md ├── ElegantRL.md ├── FinCast-fts.md ├── FinGPT.md ├── FinGenius.md ├── FinRL-Meta.md ├── FinRL.md ├── FinRobot.md ├── FinceptTerminal.md ├── Kronos.md ├── Lean.md ├── README.md ├── TradingAgents-CN.md ├── TradingAgents.md ├── TrendRadar.md ├── agentic-trading.md ├── awesome-quant.md ├── backtrader.md ├── investor-agent.md ├── panda_quantflow.md ├── qlib.md └── vnpy.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: .deepsource.toml ================================================ version = 1 [[analyzers]] name = "python" [analyzers.meta] runtime_version = "3.x.x" ================================================ FILE: .gitignore ================================================ # Development documentation (local only, not for Git) devlogs/ conclusions/ researches/ # Python __pycache__/ *.py[cod] *$py.class # Virtual environments venv/ env/ ENV/ # IDE .vscode/ .idea/ *.swp # OS .DS_Store node_modules/ **/node_modules/backend/celerybeat-schedule* backend/.crawl_cache/ backend/celerybeat-schedule backend/reproduce_sina.py backend/checkpoints/ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2025 Ziran Li Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # FinnewsHunter: Multi-Agent Investment Decision Platform Driven by Financial News
中文版 | English
FinnewsHunter Logo
An enterprise-grade financial news analysis system built on the [AgenticX](https://github.com/DemonDamon/AgenticX) framework, integrating real-time news streams, deep quantitative analysis, and multi-agent debate mechanisms.

FinnewsHunter goes beyond traditional text classification by deploying multi-agent teams (NewsAnalyst, Researcher, etc.) to monitor multiple financial news sources in real time, including Sina Finance, National Business Daily, Financial World, Securities Times, and more. It leverages large language models for deep interpretation, sentiment analysis, and market impact assessment, combined with knowledge graphs to mine potential investment opportunities and risks, providing decision-level alpha signals for quantitative trading.

---

## 🎯 Project Features

- ✅ **AgenticX Native**: Deeply integrated with the AgenticX framework, using core abstractions like Agent, Tool, and Workflow
- ✅ **AgenticX Component Integration**: Direct use of AgenticX's `BailianEmbeddingProvider` and `MilvusStorage`, avoiding reinventing the wheel
- ✅ **Agent-Driven**: The NewsAnalyst agent automatically analyzes news sentiment and market impact
- ✅ **Multi-Provider LLM Support**: Supports 5 major LLM providers (Bailian, OpenAI, DeepSeek, Kimi, Zhipu), switchable with one click in the frontend
- ✅ **Batch Operations**: Supports batch selection, batch deletion, and batch analysis of news, improving operational efficiency
- ✅ **Stock K-Line Analysis**: Integrated with akshare real market data, supporting daily/minute multi-period K-line display
- ✅ **Intelligent Stock Search**: Supports fuzzy queries by code and name, pre-loaded with data for 5,000+ A-shares
- ✅ **Complete Tech Stack**: FastAPI + PostgreSQL + Milvus + Redis + React
- ✅ **Real-time Search**: Multi-dimensional search by title, content, and stock code, with keyword highlighting
- ✅ **Async Vectorization**: Vectorization runs asynchronously in the background, non-blocking for the analysis flow
- ✅ **Production Ready**: One-click deployment with Docker Compose, complete logging and monitoring

---

## 🏗️ System Architecture

![FinnewsHunter Architecture](assets/images/arch-20251201.png)

The system adopts a layered architecture design:

- **M6 Frontend Interaction Layer**: React + TypeScript + Shadcn UI
- **M1 Platform Service Layer**: FastAPI Gateway + Task Manager
- **M4/M5 Agent Collaboration Layer**: AgenticX Agent + Debate Workflow
- **M2/M3 Infrastructure Layer**: Crawler Service + LLM Service + Embedding
- **M7-M11 Storage & Learning Layer**: PostgreSQL + Milvus + Redis + ACE Framework

---

## 🚀 Quick Start

### Prerequisites

- Python 3.11+
- Docker & Docker Compose
- (Optional) OpenAI API Key or local LLM
- Node.js 18+ (for frontend development)

### 1. Install AgenticX

```bash
# Clone https://github.com/DemonDamon/AgenticX, then install it in editable mode
cd /path/to/AgenticX
pip install -e .
```

### 2. Install Backend Dependencies

```bash
cd FinnewsHunter/backend
pip install -r requirements.txt
```

### 3.
Configure Environment Variables ```bash cd FinnewsHunter/backend cp env.example .env # Edit .env file and fill in LLM API Key and other configurations ``` **Multi-Provider LLM Configuration:** The system supports 5 LLM providers, at least one needs to be configured: | Provider | Environment Variable | Registration URL | |----------|---------------------|------------------| | Bailian (Alibaba Cloud) | `DASHSCOPE_API_KEY` | https://dashscope.console.aliyun.com/ | | OpenAI | `OPENAI_API_KEY` | https://platform.openai.com/api-keys | | DeepSeek | `DEEPSEEK_API_KEY` | https://platform.deepseek.com/ | | Kimi (Moonshot) | `MOONSHOT_API_KEY` | https://platform.moonshot.cn/ | | Zhipu | `ZHIPU_API_KEY` | https://open.bigmodel.cn/ | **Example Configuration (Recommended: Bailian):** ```bash # Bailian (Alibaba Cloud) - Recommended, fast access in China DASHSCOPE_API_KEY=sk-your-dashscope-key DASHSCOPE_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 BAILIAN_MODELS=qwen-plus,qwen-max,qwen-turbo # Optional: Other providers OPENAI_API_KEY=sk-your-openai-key DEEPSEEK_API_KEY=sk-your-deepseek-key ``` ### 4. Start Base Services (PostgreSQL, Redis, Milvus) ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml up -d postgres redis milvus-etcd milvus-minio milvus-standalone ``` ### 5. Initialize Database ```bash cd FinnewsHunter/backend python init_db.py ``` ### 5.1 Initialize Stock Data (Optional, for stock search functionality) ```bash cd FinnewsHunter/backend python -m app.scripts.init_stocks # Will fetch all A-share data (approximately 5000+ stocks) from akshare and save to database ``` ### 6. Start Backend API Service ```bash cd FinnewsHunter/backend uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 ``` ### 7. Start Celery Worker and Beat (Auto Crawling) ```bash # Open a new terminal cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml up -d celery-worker celery-beat ``` ### 8. Start Frontend Service ```bash # Open a new terminal cd FinnewsHunter/frontend npm install # First time requires dependency installation npm run dev ``` ### 9. 
Access Application - **Frontend Interface**: http://localhost:3000 - **Backend API**: http://localhost:8000 - **API Documentation**: http://localhost:8000/docs --- ## 🔄 Service Management ### View All Service Status ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml ps ``` ### Restart All Services ```bash cd FinnewsHunter # Restart Docker services (infrastructure + Celery) docker compose -f deploy/docker-compose.dev.yml restart # If backend API is started independently, manually restart it # Press Ctrl+C to stop backend process, then rerun: cd backend uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 ``` ### Restart Specific Service ```bash cd FinnewsHunter # Restart only Celery (after code changes) docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat # Restart only database docker compose -f deploy/docker-compose.dev.yml restart postgres # Restart only Redis docker compose -f deploy/docker-compose.dev.yml restart redis ``` ### Stop All Services ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml down ``` ### View Logs ```bash cd FinnewsHunter # View Celery Worker logs docker compose -f deploy/docker-compose.dev.yml logs -f celery-worker # View Celery Beat logs (scheduled task dispatch) docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat # View PostgreSQL logs docker compose -f deploy/docker-compose.dev.yml logs -f postgres # View all service logs docker compose -f deploy/docker-compose.dev.yml logs -f ``` --- ## 🗑️ Reset Database ### Method 1: Use One-Click Reset Script (Recommended) ⭐ ```bash cd FinnewsHunter # Execute reset script ./reset_all_data.sh # Enter yes to confirm ``` **The script will automatically complete:** 1. ✅ Clear all news and task data in PostgreSQL 2. ✅ Clear Redis cache 3. ✅ Reset database auto-increment IDs (restart from 1) 4. ✅ Clear Celery schedule files 5. 
✅ Automatically restart Celery services **After execution, wait:** - 5-10 minutes for the system to automatically re-crawl data - Access frontend to view new data --- ### Method 2: Manual Reset (Advanced) #### Step 1: Clear PostgreSQL Data ```bash # Enter PostgreSQL container docker exec -it finnews_postgres psql -U finnews -d finnews_db ``` Execute in PostgreSQL command line: ```sql -- Clear news table DELETE FROM news; -- Clear task table DELETE FROM crawl_tasks; -- Clear analysis table DELETE FROM analyses; -- Reset auto-increment IDs ALTER SEQUENCE news_id_seq RESTART WITH 1; ALTER SEQUENCE crawl_tasks_id_seq RESTART WITH 1; ALTER SEQUENCE analyses_id_seq RESTART WITH 1; -- Verify results (should all be 0) SELECT 'news table', COUNT(*) FROM news; SELECT 'crawl_tasks table', COUNT(*) FROM crawl_tasks; SELECT 'analyses table', COUNT(*) FROM analyses; -- Exit \q ``` #### Step 2: Clear Redis Cache ```bash cd FinnewsHunter docker exec finnews_redis redis-cli FLUSHDB ``` #### Step 3: Clear Celery Schedule Files ```bash cd FinnewsHunter/backend rm -f celerybeat-schedule* ``` #### Step 4: Restart Celery Services ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat ``` #### Step 5: Verify Data Cleared ```bash # Check news count (should be 0) docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT COUNT(*) FROM news;" # Check Redis (should be 0 or very small) docker exec finnews_redis redis-cli DBSIZE # Check if Celery has started crawling docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat # Should see 10 crawl tasks triggered per minute ``` --- ### Method 3: Use Python Script Reset ```bash cd FinnewsHunter/backend python reset_database.py # Enter yes to confirm ``` --- ### Method 4: Quick Manual Cleanup (One-Line Commands) 🔥 **Use Case:** When reset script doesn't work, this is the fastest method ```bash cd FinnewsHunter # Step 1: Clear database tables docker exec finnews_postgres psql -U finnews -d finnews_db -c "DELETE FROM news; DELETE FROM crawl_tasks; DELETE FROM analyses;" # Step 2: Reset auto-increment IDs docker exec finnews_postgres psql -U finnews -d finnews_db -c "ALTER SEQUENCE news_id_seq RESTART WITH 1; ALTER SEQUENCE crawl_tasks_id_seq RESTART WITH 1; ALTER SEQUENCE analyses_id_seq RESTART WITH 1;" # Step 3: Clear Redis cache docker exec finnews_redis redis-cli FLUSHDB # Step 4: Clear Celery schedule files rm -f backend/celerybeat-schedule* # Step 5: Restart Celery services docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat # Step 6: Verify cleared (should display 0) docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT COUNT(*) FROM news;" ``` **Immediately refresh browser after execution:** - Mac: `Command + Shift + R` - Windows: `Ctrl + Shift + R` --- ### 🖥️ Clear Frontend Cache (Important!) **After data is cleared, frontend may still display old data due to browser cache.** #### Method 1: Hard Refresh Browser (Recommended) ⭐ **Mac System:** ``` Press Command + Shift + R or Command + Option + R ``` **Windows/Linux System:** ``` Press Ctrl + Shift + R or Ctrl + F5 ``` #### Method 2: Developer Tools Clear Cache 1. Press `F12` to open developer tools 2. Right-click the refresh button (next to address bar) 3. Select **"Empty Cache and Hard Reload"** #### Method 3: Clear Browser Cache 1. 
**Chrome/Edge:** - `Command + Shift + Delete` (Mac) or `Ctrl + Shift + Delete` (Windows) - Check "Cached images and files" - Time range select "All time" - Click "Clear data" 2. **After refreshing page, hard refresh again** - Ensure React Query cache is also cleared #### Method 4: Restart Frontend Dev Server (Most Thorough) ```bash # Press Ctrl+C in frontend terminal to stop service # Then restart cd FinnewsHunter/frontend npm run dev ``` --- ## 📊 Data Recovery Timeline After Reset | Time | Event | Expected Result | |------|-------|----------------| | 0 min | Execute reset script | Database cleared, Redis cleared | | 1 min | Celery Beat starts scheduling | 10 crawl tasks triggered | | 2-5 min | First batch of news saved | Database starts having data | | 5-10 min | All sources have data | Frontend can see 100+ news | | 30 min | Data continues growing | 500+ news | | 1 hour | Stable operation | 1000-2000 news | **Notes:** - Need to wait 5-10 minutes after reset to see new data - **Frontend must hard refresh** (Command+Shift+R / Ctrl+Shift+R) to clear cache - Don't reset frequently, affects system stability **Steps to immediately hard refresh frontend after reset:** 1. Execute reset command 2. **Immediately** press `Command + Shift + R` (Mac) or `Ctrl + Shift + R` (Windows) in browser 3. Wait 5-10 minutes then refresh again to view new data --- ## ⚠️ Crawler Status Check ### Check Which Sources Are Working ```bash cd FinnewsHunter # View news count by source docker exec finnews_postgres psql -U finnews -d finnews_db -c " SELECT source, COUNT(*) as count FROM news WHERE created_at > NOW() - INTERVAL '1 hour' GROUP BY source ORDER BY count DESC; " # View recent crawl task status docker exec finnews_postgres psql -U finnews -d finnews_db -c " SELECT source, crawled_count, saved_count, status, error_message FROM crawl_tasks WHERE created_at > NOW() - INTERVAL '10 minutes' ORDER BY created_at DESC LIMIT 20; " ``` ### View Crawl Errors ```bash cd FinnewsHunter # View ERROR logs docker compose -f deploy/docker-compose.dev.yml logs celery-worker | grep ERROR # View specific source issues docker compose -f deploy/docker-compose.dev.yml logs celery-worker | grep "jwview" ``` --- ## 📚 User Guide ### Auto Crawl Mode (Recommended) ⭐ **System is configured with automatic crawling for 10 news sources:** 1. 🌐 Sina Finance 2. 🐧 Tencent Finance 3. 💰 Financial World 4. 📊 Economic Observer 5. 📈 Caijing.com 6. 📉 21st Century Business Herald 7. 📰 National Business Daily 8. 🎯 Yicai 9. 📧 NetEase Finance 10. 💎 East Money **How it works:** - ✅ Celery Beat automatically triggers crawling for all sources every 1 minute - ✅ Automatic deduplication (URL level) - ✅ Smart time filtering (keep news within 24 hours) - ✅ Stock keyword filtering - ✅ No manual operation needed **View crawl progress:** ```bash # View Celery Beat scheduling logs cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat # View Celery Worker execution logs docker compose -f deploy/docker-compose.dev.yml logs -f celery-worker ``` --- ### Manual Refresh (Get Latest Immediately) **Method 1: Via Frontend** 1. Visit http://localhost:3000/news 2. Click the "🔄 Refresh Now" button in the top right 3. 
System will immediately trigger crawling, data updates in about 2 minutes **Method 2: Via API** ```bash # Force refresh Sina Finance curl -X POST "http://localhost:8000/api/v1/news/refresh?source=sina" # Force refresh all sources (need to call individually) for source in sina tencent jwview eeo caijing jingji21 nbd yicai 163 eastmoney; do curl -X POST "http://localhost:8000/api/v1/news/refresh?source=$source" sleep 1 done ``` --- ### View News List **Method 1: Via Frontend (Recommended)** - Visit http://localhost:3000 - Homepage: View source statistics and latest news - News Feed: Filter news by source and sentiment - Batch selection support: Use checkboxes to select multiple news, supports Shift key range selection - Batch operations: Select all/deselect all, batch delete, batch analyze **Method 2: Via API** ```bash # Get latest news from all sources (200 items) curl "http://localhost:8000/api/v1/news/latest?limit=200" # Get news from specific source curl "http://localhost:8000/api/v1/news/latest?source=sina&limit=50" # Filter by sentiment (using old API) curl "http://localhost:8000/api/v1/news/?sentiment=positive&limit=20" # Get all available news source list curl "http://localhost:8000/api/v1/news/sources" ``` --- ### Batch Operations on News **Frontend Operations:** 1. **Batch Selection**: - Click checkbox on the left of news card to select single news - Hold Shift key and click for range selection - Use "Select All" button in top toolbar to select all news in current filter results - Selection state automatically clears when switching news source or filter conditions 2. **Batch Delete**: - After selecting multiple news, click "Batch Delete" button in top toolbar - After confirming delete dialog, selected news will be deleted - List automatically refreshes after deletion 3. **Batch Analysis**: - After selecting multiple news, click "Batch Analyze" button in top toolbar - System will analyze selected news sequentially, showing progress and result statistics - After analysis completes, shows success/failure count **API Operations:** ```bash # Batch delete news curl -X POST "http://localhost:8000/api/v1/news/batch/delete" \ -H "Content-Type: application/json" \ -d '{"news_ids": [1, 2, 3]}' # Batch analyze news curl -X POST "http://localhost:8000/api/v1/analysis/batch" \ -H "Content-Type: application/json" \ -d '{"news_ids": [1, 2, 3], "provider": "bailian", "model": "qwen-plus"}' ``` --- ### Analyze News **Method 1: Via Frontend** - Click "✨ Analyze" button on news card - Wait 3-5 seconds to view analysis results - Click news card to open detail drawer, view complete analysis content **Method 2: Via API** ```bash # Analyze news with specified ID (using default model) curl -X POST http://localhost:8000/api/v1/analysis/news/1 # Analyze news (specify model) curl -X POST http://localhost:8000/api/v1/analysis/news/1 \ -H "Content-Type: application/json" \ -d '{"provider": "bailian", "model": "qwen-max"}' # View analysis results curl http://localhost:8000/api/v1/analysis/1 ``` --- ### Switch LLM Model **Frontend Operations:** 1. Click model selector in top right (shows current model name) 2. Select different provider and model from dropdown menu 3. 
Selection automatically saves, subsequent analyses will use new model **Supported Models:** - 🔥 **Bailian**: qwen-plus, qwen-max, qwen-turbo, qwen-long - 🤖 **OpenAI**: gpt-4, gpt-4-turbo, gpt-3.5-turbo - 🧠 **DeepSeek**: deepseek-chat, deepseek-coder - 🌙 **Kimi**: moonshot-v1-8k, moonshot-v1-32k, moonshot-v1-128k - 🔮 **Zhipu**: glm-4, glm-4-plus, glm-4-air **API to Get Available Model List:** ```bash curl http://localhost:8000/api/v1/llm/config ``` --- ### Search News **Frontend Operations:** 1. Enter keywords in top search box 2. Supports search: title, content, stock code, source 3. Matching keywords will be highlighted 4. Search has 300ms debounce, automatically searches after input stops **Search Examples:** - Search stock code: `600519` (Kweichow Moutai) - Search keywords: `新能源` (new energy), `半导体` (semiconductor) - Search source: `sina`, `eastmoney` --- ### View News Details **Frontend Operations:** 1. Click any news card 2. Detail drawer slides out from right, displaying: - 📰 News title and source - 📊 Sentiment score (positive/negative/neutral) - 📈 Associated stock codes - 📝 Complete news content - 🤖 AI analysis results (Markdown format) - 🔗 Original article link 3. Click "Copy Analysis Content" to copy analysis report in Markdown format --- ### Stock K-Line Analysis **Frontend Operations:** 1. Visit http://localhost:3000/stocks/SH600519 (Kweichow Moutai example) 2. Use top right search box to enter stock code or name (e.g., `茅台` (Moutai), `600519`) 3. Select time period: Daily K, 60min, 30min, 15min, 5min, 1min 4. Chart supports: - 📈 K-line candlestick chart (OHLC) - 📊 Volume bar chart - 📉 MA moving averages (5/10/30/60 day) **API Operations:** ```bash # Get K-line data (daily, default 180 items) curl "http://localhost:8000/api/v1/stocks/SH600519/kline?period=daily&limit=180" # Get minute K-line (60-minute line) curl "http://localhost:8000/api/v1/stocks/SH600519/kline?period=60m&limit=200" # Search stocks curl "http://localhost:8000/api/v1/stocks/search/realtime?q=茅台&limit=10" # View stock count in database curl "http://localhost:8000/api/v1/stocks/count" ``` --- ### Filter by Source **Frontend Operations:** 1. **Homepage (Dashboard)** - View "News Source Statistics" card - Click any source button to filter - Display news count and list for that source 2. **News Feed Page** - Top has 10 source filter buttons - Click to switch and view different sources - Supports source + sentiment dual filtering **API Operations:** ```bash # View Sina Finance news curl "http://localhost:8000/api/v1/news/latest?source=sina&limit=50" # View National Business Daily news curl "http://localhost:8000/api/v1/news/latest?source=nbd&limit=50" # View all sources curl "http://localhost:8000/api/v1/news/latest?limit=200" ``` --- ## 🏗️ Project Structure ``` FinnewsHunter/ ├── backend/ # Backend service │ ├── app/ │ │ ├── agents/ # Agent definitions (NewsAnalyst, debate agents, etc.) │ │ ├── api/v1/ # FastAPI routes │ │ │ ├── analysis.py # Analysis API (supports batch analysis) │ │ │ ├── llm_config.py # LLM config API │ │ │ ├── news_v2.py # News API (supports batch delete) │ │ │ └── ... 
│ │ ├── core/ # Core configuration (config, database, redis, neo4j) │ │ ├── models/ # SQLAlchemy data models │ │ ├── services/ # Business services │ │ │ ├── llm_service.py # LLM service (multi-provider support) │ │ │ ├── analysis_service.py # Analysis service (async vectorization) │ │ │ ├── embedding_service.py # Vectorization service (based on AgenticX BailianEmbeddingProvider) │ │ │ └── stock_data_service.py # Stock data service │ │ ├── storage/ # Storage wrapper │ │ │ └── vector_storage.py # Milvus vector storage (based on AgenticX MilvusStorage) │ │ ├── tasks/ # Celery tasks │ │ └── tools/ # AgenticX tools (Crawler, Cleaner) │ ├── tests/ # Test and utility scripts │ │ ├── check_milvus_data.py # Check Milvus vector storage data │ │ ├── check_news_embedding_status.py # Check news vectorization status │ │ └── manual_vectorize.py # Manually vectorize specified news │ ├── env.example # Environment variable template │ └── requirements.txt # Python dependencies ├── frontend/ # React frontend │ └── src/ │ ├── components/ # Components │ │ ├── ModelSelector.tsx # LLM model selector │ │ ├── NewsDetailDrawer.tsx # News detail drawer │ │ └── HighlightText.tsx # Keyword highlighting │ ├── context/ # React Context │ ├── hooks/ # Custom Hooks │ │ └── useDebounce.ts # Debounce Hook │ ├── layout/ # Layout components │ └── pages/ # Page components │ └── NewsListPage.tsx # News list page (supports batch operations) ├── deploy/ # Deployment configuration │ ├── docker-compose.dev.yml # Docker Compose configuration │ ├── Dockerfile.celery # Celery image build file │ └── celery-entrypoint.sh # Celery container startup script ├── conclusions/ # Module summary documentation │ ├── backend/ # Backend module summaries │ └── frontend/ # Frontend module summaries └── .dev-docs/ # Development documentation ``` --- ## 🧪 Testing & Acceptance ### MVP Acceptance Criteria - [x] News crawling successful and saved to PostgreSQL - [x] NewsAnalyst calls LLM to complete analysis - [x] Analysis results include sentiment scores - [x] Frontend can display news and analysis results - [x] Support multi-provider LLM dynamic switching - [x] News details display complete analysis content - [x] Real-time search and filtering functionality - [x] Batch selection, batch delete, batch analysis functionality - [x] Vectorization and storage services based on AgenticX - [x] Async vectorization, non-blocking analysis flow ### Testing Process 1. **Start All Services** ```bash ./start.sh ``` 2. **Check Docker Container Status** ```bash docker ps # Should see: postgres, redis, milvus-standalone, milvus-etcd, milvus-minio ``` 3. **Test News Crawling** ```bash curl -X POST http://localhost:8000/api/v1/news/crawl \ -H "Content-Type: application/json" \ -d '{"source": "sina", "start_page": 1, "end_page": 1}' # Wait 5-10 seconds then check results curl http://localhost:8000/api/v1/news/?limit=5 ``` 4. **Test Agent Analysis** ```bash # Get first news ID NEWS_ID=$(curl -s http://localhost:8000/api/v1/news/?limit=1 | jq '.[0].id') # Trigger analysis curl -X POST http://localhost:8000/api/v1/analysis/news/$NEWS_ID # View analysis results curl http://localhost:8000/api/v1/analysis/1 ``` 5. 
**Test Frontend Interface** - Open `frontend/index.html` - Click "Crawl News" and wait for completion - Select a news item and click "Analyze" - Check if sentiment score is displayed --- ## 🔧 Troubleshooting ### Issue 1: Database Connection Failed **Symptom:** Backend startup error `could not connect to database` **Solution:** ```bash cd FinnewsHunter # Check if PostgreSQL is running docker ps | grep postgres # View logs docker compose -f deploy/docker-compose.dev.yml logs postgres # Restart container docker compose -f deploy/docker-compose.dev.yml restart postgres # Wait 30 seconds then retry backend startup ``` --- ### Issue 2: Celery Tasks Not Executing **Symptom:** Frontend shows 0 news count, no automatic crawling **Troubleshooting Steps:** ```bash cd FinnewsHunter # 1. Check if Celery Worker is running docker ps | grep celery # 2. View Celery Beat logs (should see tasks triggered every minute) docker compose -f deploy/docker-compose.dev.yml logs celery-beat --tail=100 # 3. View Celery Worker logs (check task execution) docker compose -f deploy/docker-compose.dev.yml logs celery-worker --tail=100 # 4. Check Redis connection docker exec finnews_redis redis-cli PING # Should return PONG # 5. Restart Celery services docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat ``` --- ### Issue 3: Crawling Failed (404 Error) **Symptom:** Celery logs show `404 Client Error: Not Found` **Cause:** News website URL has changed **Solution:** ```bash # 1. Manually visit URL to verify if available curl -I https://finance.caijing.com.cn/ # 2. If URL changed, update corresponding crawler configuration # Edit backend/app/tools/{source}_crawler.py # Update BASE_URL and STOCK_URL # 3. Clear Python cache cd FinnewsHunter/backend find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true # 4. Restart Celery cd .. docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat ``` --- ### Issue 4: Only Sina Finance Has Data **Symptom:** Other 9 sources have no news **Possible Causes:** 1. Celery Beat configuration incomplete 2. Crawler code has errors 3. Website URL incorrect **Solution:** ```bash cd FinnewsHunter # 1. Check Celery Beat configuration docker compose -f deploy/docker-compose.dev.yml logs celery-beat | grep "crawl-" # Should see 10 scheduled tasks (crawl-sina, crawl-tencent, ..., crawl-eastmoney) # 2. Manually test single source crawling docker exec -it finnews_celery_worker python -c " from app.tools import get_crawler_tool crawler = get_crawler_tool('nbd') # Test National Business Daily news = crawler.crawl() print(f'Crawled {len(news)} news items') " # 3. View data volume by source in database docker exec finnews_postgres psql -U finnews -d finnews_db -c " SELECT source, COUNT(*) as count FROM news GROUP BY source ORDER BY count DESC; " # 4. If a source keeps failing, view detailed errors docker compose -f deploy/docker-compose.dev.yml logs celery-worker | grep "ERROR" ``` --- ### Issue 5: LLM Call Failed **Symptom:** Analysis functionality not working, error `LLM Provider NOT provided` **Solution:** ```bash cd FinnewsHunter/backend # 1. Check if API Key is configured grep -E "DASHSCOPE_API_KEY|OPENAI_API_KEY|DEEPSEEK_API_KEY" .env # 2. Check if Base URL is correct (Bailian must configure) grep DASHSCOPE_BASE_URL .env # Should be: https://dashscope.aliyuncs.com/compatible-mode/v1 # 3. Verify LLM config API is normal curl http://localhost:8000/api/v1/llm/config | jq '.providers[].has_api_key' # At least one should return true # 4. 
If using Bailian, ensure complete configuration
cat >> .env << EOF
DASHSCOPE_API_KEY=sk-your-key
DASHSCOPE_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
BAILIAN_MODELS=qwen-plus,qwen-max
EOF

# 5. Restart backend service
```

---

### Issue 6: Frontend Shows Blank or CORS Error

**Symptom:** Frontend cannot load data, browser Console shows CORS error

**Solution:**

```bash
# 1. Check backend CORS configuration
cd FinnewsHunter/backend
grep BACKEND_CORS_ORIGINS .env
# Should include http://localhost:3000

# 2. Check frontend API address configuration
cd ../frontend
cat .env
# VITE_API_URL should be http://localhost:8000

# 3. Hard refresh browser
# Chrome/Edge: Ctrl+Shift+R (Windows) or Cmd+Shift+R (Mac)

# 4. Restart frontend dev server
npm run dev
```

---

### Issue 7: Milvus Connection Failed

**Symptom:** Vector search functionality not working

**Solution:**

```bash
cd FinnewsHunter

# Milvus requires longer startup time (approximately 60 seconds)
docker compose -f deploy/docker-compose.dev.yml logs milvus-standalone

# Check health status
docker inspect finnews_milvus | grep -A 10 Health

# Restart Milvus related services
docker compose -f deploy/docker-compose.dev.yml restart milvus-etcd milvus-minio milvus-standalone
```

---

### Issue 8: Data Statistics Inaccurate

**Symptom:** Homepage shows news count doesn't match actual

**Solution:**

```bash
# Use reset script to clear data and start fresh
cd FinnewsHunter
./reset_all_data.sh
```

---

### Common Debugging Commands

```bash
cd FinnewsHunter

# View all container status
docker compose -f deploy/docker-compose.dev.yml ps

# View complete logs for a service
docker compose -f deploy/docker-compose.dev.yml logs celery-worker --tail=500

# Enter container for debugging
docker exec -it finnews_celery_worker bash

# View database connection
docker exec finnews_postgres psql -U finnews -d finnews_db -c "\conninfo"

# View Redis connection
docker exec finnews_redis redis-cli INFO

# Test network connectivity
docker exec finnews_celery_worker ping -c 3 postgres
```

---

## ⚡ Quick Reference (Common Commands)

### Project Directory

```bash
cd FinnewsHunter
```

### One-Click Operations

```bash
# Start all services
docker compose -f deploy/docker-compose.dev.yml up -d

# Stop all services
docker compose -f deploy/docker-compose.dev.yml down

# Restart Celery (after code updates)
docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat

# Clear all data and start fresh
./reset_all_data.sh
```

### View Status

```bash
# Service status
docker compose -f deploy/docker-compose.dev.yml ps

# News count
docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT source, COUNT(*) FROM news GROUP BY source;"

# Task count
docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT status, COUNT(*) FROM crawl_tasks GROUP BY status;"

# Redis cache
docker exec finnews_redis redis-cli DBSIZE
```

### View Logs

```bash
# Celery Beat (scheduled dispatch)
docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat

# Celery Worker (task execution)
docker compose -f deploy/docker-compose.dev.yml logs -f celery-worker

# PostgreSQL
docker compose -f deploy/docker-compose.dev.yml logs -f postgres

# All services
docker compose -f deploy/docker-compose.dev.yml logs -f
```

### Direct Access

- **Frontend**: http://localhost:3000
- **Backend API**: http://localhost:8000
- **API Documentation**: http://localhost:8000/docs

---

## 📊 Database Structure

### News Table

- id, title, content, url, source
- publish_time, stock_codes
- sentiment_score, is_embedded

### Analysis Table

- id, news_id, agent_name
- sentiment, sentiment_score, confidence
- analysis_result, structured_data

### Stock Table

- id, code, name, industry, market
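Column types are not spelled out above; the authoritative definitions live in `backend/app/models/`. Purely as orientation, a minimal SQLAlchemy sketch of the `news` table, with field names taken from the list above but types and lengths assumed rather than copied from the actual model, could look like this:

```python
# Hypothetical sketch of the News model; types/lengths are assumptions, not the real definition.
from sqlalchemy import Column, Integer, String, Text, Float, Boolean, DateTime
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class News(Base):
    __tablename__ = "news"

    id = Column(Integer, primary_key=True, autoincrement=True)
    title = Column(String(512), nullable=False)
    content = Column(Text)
    url = Column(String(1024), unique=True, index=True)   # deduplication is URL-level
    source = Column(String(64), index=True)               # e.g. "sina", "eastmoney"
    publish_time = Column(DateTime, index=True)
    stock_codes = Column(String(256))                     # associated stock codes
    sentiment_score = Column(Float)                        # filled in by NewsAnalyst
    is_embedded = Column(Boolean, default=False)           # set once the vector is stored in Milvus
```

The Analysis and Stock tables follow the same pattern with the fields listed above.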
---

## 🛠️ Development Guide

### Add New Crawler

1. Inherit `BaseCrawler` class
2. Implement `crawl()` method
3. Register in `tools/__init__.py`

Example:

```python
# backend/app/tools/custom_crawler.py
from .crawler_base import BaseCrawler

class CustomCrawlerTool(BaseCrawler):
    name = "custom_crawler"

    def crawl(self, start_page, end_page):
        # Implement crawling logic
        pass
```
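As a purely hypothetical illustration of what the stub above might grow into, the sketch below fetches a listing page with `requests`/`BeautifulSoup` and returns plain dictionaries. The listing URL, CSS selector, and return shape are assumptions made for illustration; check `crawler_base.py` for the exact contract `BaseCrawler` expects:

```python
# Illustrative only: the listing URL, selectors, and return shape are assumptions.
import requests
from bs4 import BeautifulSoup

from .crawler_base import BaseCrawler

class CustomCrawlerTool(BaseCrawler):
    name = "custom_crawler"

    def crawl(self, start_page=1, end_page=1):
        items = []
        for page in range(start_page, end_page + 1):
            # Hypothetical listing URL; replace with the real news source
            resp = requests.get(f"https://example.com/finance/list_{page}.html", timeout=10)
            resp.raise_for_status()
            soup = BeautifulSoup(resp.text, "html.parser")
            for link in soup.select("a.news-title"):  # assumed selector
                items.append({
                    "title": link.get_text(strip=True),
                    "url": link["href"],
                    "source": self.name,
                })
        return items
```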
### Use Enhanced Crawler (Optional)

For scenarios requiring JS rendering or intelligent content extraction, use the enhanced crawler:

```python
from app.tools.crawler_enhanced import crawl_url, EnhancedCrawler

# Quick crawl single URL
article = crawl_url("https://finance.sina.com.cn/xxx", engine='auto')
print(article.to_markdown())

# Get LLM message format (multimodal)
llm_messages = article.to_llm_message()

# Batch crawl (with cache)
crawler = EnhancedCrawler(use_cache=True)
articles = crawler.crawl_batch(urls, delay=1.0)
```

**Supported Engines:**

- `requests`: Basic HTTP requests (default)
- `playwright`: JS rendering (requires `playwright install chromium`)
- `jina`: Jina Reader API (requires `JINA_API_KEY` configuration)
- `auto`: Automatically select the best engine

**Install Optional Dependencies:**

```bash
pip install markdownify readabilipy playwright
playwright install chromium  # Optional, for JS rendering
```

---

### Add New Agent

1. Inherit `Agent` class
2. Define role, goal, backstory
3. Implement business methods

Example:

```python
# backend/app/agents/risk_analyst.py
from agenticx import Agent

class RiskAnalystAgent(Agent):
    def __init__(self, llm_provider):
        super().__init__(
            name="RiskAnalyst",
            role="Risk Analyst",
            goal="Assess investment risks",
            llm_provider=llm_provider
        )
```

---

### Using AgenticX Components

FinnewsHunter deeply integrates AgenticX framework core components to avoid reinventing the wheel:

#### 1. Embedding Service

The system uses `agenticx.embeddings.BailianEmbeddingProvider` as the core embedding engine:

```python
from app.services.embedding_service import EmbeddingService

# Synchronous interface (for sync contexts)
embedding_service = EmbeddingService()
vector = embedding_service.embed_text("text content")

# Asynchronous interface (recommended for async contexts)
vector = await embedding_service.aembed_text("text content")

# Batch processing (Provider handles internal batching)
vectors = embedding_service.embed_batch(["text1", "text2", "text3"])
```

**Features**:

- Redis caching support to avoid duplicate calculations
- Automatic text length limit handling (6000 characters)
- Both sync and async interfaces to avoid event loop conflicts

#### 2. Vector Storage (Milvus)

The system uses `agenticx.storage.vectordb_storages.milvus.MilvusStorage` as the vector database:

```python
from app.storage.vector_storage import VectorStorage

vector_storage = VectorStorage()

# Store single vector
vector_storage.store_embedding(
    news_id=1,
    text="news content",
    embedding=[0.1, 0.2, ...]
)

# Batch storage
vector_storage.store_embeddings_batch([
    {"news_id": 1, "text": "content1", "embedding": [...]},
    {"news_id": 2, "text": "content2", "embedding": [...]}
])

# Similarity search
results = vector_storage.search_similar(query_vector=[...], top_k=10)

# Get statistics (with query count fallback mechanism)
stats = vector_storage.get_stats()
```

**Features**:

- Direct use of AgenticX MilvusStorage, no duplicate implementation
- Compatibility interface for simplified calls
- Query count fallback when `num_entities` is inaccurate
- Async operation support to avoid blocking

#### 3. Async Embedding Best Practices

In async contexts (e.g., FastAPI routes), use the async interfaces:

```python
import asyncio

from app.services.embedding_service import EmbeddingService
from app.storage.vector_storage import VectorStorage

async def analyze_news(news_id: int, text: str):
    embedding_service = EmbeddingService()
    vector_storage = VectorStorage()

    # Use async interface to avoid event loop conflicts
    embedding = await embedding_service.aembed_text(text)

    # Store vector asynchronously in background (non-blocking)
    asyncio.create_task(
        vector_storage.store_embedding(news_id, text, embedding)
    )

    # Continue with analysis logic...
```

**Notes**:

- In async contexts, use `aembed_text()` instead of `embed_text()`
- Embedding operations run asynchronously in the background, non-blocking
- The Milvus `flush()` operation is optimized and not executed by default (relies on auto-flush)

---

## Multi-Agent Debate Architecture

FinnewsHunter's core feature is the **bull-bear debate mechanism**: through the collaboration and confrontation of multiple professional agents, it deeply mines the investment value and risks of individual stocks.

### Core Participants

| Agent | Role | Core Responsibilities |
|-------|------|---------------------|
| **BullResearcher** | Bull Researcher | Mine growth potential, core positives, valuation advantages |
| **BearResearcher** | Bear Researcher | Identify downside risks, negative catalysts, refute optimistic expectations |
| **SearchAnalyst** | Search Analyst | Dynamically acquire data (AkShare/BochaAI/browser search) |
| **InvestmentManager** | Investment Manager | Host debate, evaluate argument quality, make final decisions |

### Debate Data Flow Architecture

```mermaid
graph TD
    subgraph Debate Initiation
        Manager[Investment Manager] -->|Opening Statement| Orchestrator[Debate Orchestrator]
    end
    subgraph Multi-Round Debate
        Orchestrator -->|Round N| Bull[Bull Researcher]
        Bull -->|Statement + Data Request| Orchestrator
        Orchestrator -->|Trigger Search| Searcher[Search Analyst]
        Searcher -->|Financial Data| AkShare[AkShare]
        Searcher -->|Real-time News| BochaAI[BochaAI]
        Searcher -->|Web Search| Browser[Browser Engine]
        AkShare --> Context[Update Context]
        BochaAI --> Context
        Browser --> Context
        Context --> Orchestrator
        Orchestrator -->|Round N| Bear[Bear Researcher]
        Bear -->|Statement + Data Request| Orchestrator
    end
    subgraph Final Decision
        Orchestrator -->|Intelligent Data Supplement| Searcher
        Orchestrator -->|Comprehensive Judgment| Manager
        Manager -->|Investment Rating| Result[Final Report]
    end
```

### Dynamic Search Mechanism

During a debate, agents can request additional data through a specific format (a parsing sketch follows the list of data sources below):

```
[SEARCH: "Recent gross margin data" source:akshare]      -- Get financial data from AkShare
[SEARCH: "Industry competition analysis" source:bochaai] -- Search news from BochaAI
[SEARCH: "Recent fund flows" source:akshare]             -- Get fund flows
[SEARCH: "Competitor comparison analysis"]               -- Automatically select best data source
```

**Supported Data Sources:**

- **AkShare**: Financial indicators, K-line market data, fund flows, institutional holdings
- **BochaAI**: Real-time news search, analyst reports
- **Browser Search**: Baidu News, Sogou, 360 and other multi-engine search
- **Knowledge Base**: Historical news and analysis data
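To make the directive format concrete, here is a minimal, self-contained sketch of how `[SEARCH: ...]` requests could be pulled out of an agent statement with a regular expression. It is illustrative only; the regex and the `"auto"` fallback are assumptions, not the orchestrator's actual parsing logic:

```python
# Minimal sketch: extract [SEARCH: "..."] directives from an agent statement.
# The regex and the "auto" default source are assumptions for illustration only.
import re

SEARCH_PATTERN = re.compile(r'\[SEARCH:\s*"(?P<query>[^"]+)"(?:\s+source:(?P<source>\w+))?\]')

def extract_search_requests(statement: str) -> list[dict]:
    """Return one dict per [SEARCH: ...] directive found in the text."""
    requests = []
    for match in SEARCH_PATTERN.finditer(statement):
        requests.append({
            "query": match.group("query"),
            # When no source is given, let the orchestrator pick the best one
            "source": match.group("source") or "auto",
        })
    return requests

statement = 'The margin trend is unclear. [SEARCH: "Recent gross margin data" source:akshare]'
print(extract_search_requests(statement))
# [{'query': 'Recent gross margin data', 'source': 'akshare'}]
```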
---

## 📈 Roadmap

### Phase 1: MVP (Completed) ✅

- [x] Project infrastructure
- [x] Database models
- [x] Crawler tool refactoring (10 news sources)
- [x] LLM service integration
- [x] NewsAnalyst agent
- [x] FastAPI routes
- [x] React + TypeScript frontend

### Phase 1.5: Multi-Provider LLM Support (Completed) ✅

- [x] Support 5 major LLM providers (Bailian, OpenAI, DeepSeek, Kimi, Zhipu)
- [x] Frontend dynamic model switching
- [x] LLM config API (`/api/v1/llm/config`)
- [x] News detail drawer (complete content + AI analysis)
- [x] Real-time search functionality (multi-dimensional + keyword highlighting)
- [x] Markdown rendering (supports tables, code blocks)
- [x] One-click copy analysis report

### Phase 1.6: Stock Analysis & Enhanced Crawler (Completed) ✅

- [x] Stock K-line charts (integrated akshare + klinecharts)
- [x] Multi-period support (Daily K/60min/30min/15min/5min/1min)
- [x] Stock search (code/name fuzzy query, pre-loaded 5000+ A-shares)
- [x] Enhanced crawler module
- [x] Multi-engine support (Requests/Playwright/Jina)
- [x] Intelligent content extraction (readabilipy + heuristic algorithms)
- [x] Content quality assessment and auto-retry
- [x] Cache mechanism and unified Article model

### Phase 1.7: AgenticX Deep Integration & Batch Operations (Completed) ✅

- [x] Migrated to AgenticX BailianEmbeddingProvider (removed redundant batch processing logic)
- [x] Migrated to AgenticX MilvusStorage (simplified storage wrapper, removed duplicate code)
- [x] Async vectorization interfaces (aembed_text/aembed_batch), avoid event loop conflicts
- [x] Background async vectorization, non-blocking analysis flow
- [x] Milvus statistics optimization (query count fallback mechanism)
- [x] Frontend batch selection functionality (checkboxes + Shift range selection)
- [x] Batch delete news functionality
- [x] Batch analyze news functionality (with progress display and result statistics)
- [x] Docker Compose optimization (Celery image build, improved startup performance)

### Phase 2: Multi-Agent Debate (Completed) ✅

- [x] BullResearcher & BearResearcher agents
- [x] SearchAnalyst search analyst (dynamic data acquisition)
- [x] InvestmentManager investment manager decision
- [x] Debate orchestrator (DebateOrchestrator)
- [x] Dynamic search mechanism (on-demand data acquisition during debate)
- [x] Three debate modes: parallel analysis, real-time debate, quick analysis
- [ ] Real-time WebSocket push (in progress)
- [ ] Agent execution trace visualization (in progress)

### Phase 3: Knowledge Enhancement (Planned)

- [ ] Financial knowledge graph (Neo4j)
- [ ] Agent memory system
- [ ] GraphRetriever graph retrieval

### Phase 4: Self-Evolution (Planned)

- [ ] ACE framework integration
- [ ] Investment strategy Playbook
- [ ] Decision effectiveness evaluation and learning

---

## 📄 License

This project follows the AgenticX license.
---

## 🙏 Acknowledgments

- [AgenticX](https://github.com/DemonDamon/AgenticX) - Multi-agent framework
- [FastAPI](https://fastapi.tiangolo.com/) - Web framework
- [Milvus](https://milvus.io/) - Vector database
- [Alibaba Cloud Bailian](https://dashscope.console.aliyun.com/) - LLM service
- [Shadcn UI](https://ui.shadcn.com/) - Frontend component library

---

## ⭐ Star History

If you find this project helpful, please give it a Star ⭐️!

[![Star History Chart](https://api.star-history.com/svg?repos=DemonDamon/FinnewsHunter&type=Date)](https://star-history.com/#DemonDamon/FinnewsHunter&Date)

---

**Built with ❤️ using AgenticX**

================================================
FILE: README_zn.md
================================================

# FinnewsHunter:金融新闻驱动的多智能体投资决策平台
中文版 | English
FinnewsHunter Logo
基于 [AgenticX](https://github.com/DemonDamon/AgenticX) 框架构建的企业级金融新闻分析系统,融合实时新闻流、深度量化分析和多智能体辩论机制。 FinnewsHunter 不再局限于传统的文本分类,而是部署多智能体战队(NewsAnalyst, Researcher 等),实时监控新浪财经、每经网、金融界、证券时报等多源财经资讯。利用大模型进行深度解读、情感分析与市场影响评估,并结合知识图谱挖掘潜在的投资机会与风险,为量化交易提供决策级别的阿尔法信号。 --- ## 🎯 项目特色 - ✅ **AgenticX 原生**: 深度集成 AgenticX 框架,使用 Agent、Tool、Workflow 等核心抽象 - ✅ **AgenticX 组件集成**: 直接使用 AgenticX 的 `BailianEmbeddingProvider` 和 `MilvusStorage`,避免重复造轮子 - ✅ **智能体驱动**: NewsAnalyst 智能体自动分析新闻情感和市场影响 - ✅ **多厂商 LLM 支持**: 支持百炼、OpenAI、DeepSeek、Kimi、智谱 5 大厂商,前端一键切换 - ✅ **批量操作**: 支持批量选择、批量删除、批量分析新闻,提高操作效率 - ✅ **股票 K 线分析**: 集成 akshare 真实行情数据,支持日K/分K多周期展示 - ✅ **股票智能搜索**: 支持代码和名称模糊查询,预加载 5000+ A股数据 - ✅ **完整技术栈**: FastAPI + PostgreSQL + Milvus + Redis + React - ✅ **实时搜索**: 支持标题、内容、股票代码多维度搜索,关键词高亮 - ✅ **异步向量化**: 后台异步执行向量化,不阻塞分析流程 - ✅ **生产就绪**: Docker Compose 一键部署,日志、监控完备 --- ## 🏗️ 系统架构 ![FinnewsHunter Architecture](assets/images/arch-20251201.png) 系统采用分层架构设计: - **M6 前端交互层**: React + TypeScript + Shadcn UI - **M1 平台服务层**: FastAPI Gateway + Task Manager - **M4/M5 智能体协同层**: AgenticX Agent + Debate Workflow - **M2/M3 基础设施层**: Crawler Service + LLM Service + Embedding - **M7-M11 存储与学习层**: PostgreSQL + Milvus + Redis + ACE Framework --- ## 🚀 快速开始 ### 前置条件 - Python 3.11+ - Docker & Docker Compose - (可选) OpenAI API Key 或本地 LLM - Node.js 18+ (前端开发) ### 1. 安装 AgenticX ```bash cd /Users/damon/myWork/AgenticX pip install -e . ``` ### 2. 安装后端依赖 ```bash cd FinnewsHunter/backend pip install -r requirements.txt ``` ### 3. 配置环境变量 ```bash cd FinnewsHunter/backend cp env.example .env # 编辑 .env 文件,填入 LLM API Key 等配置 ``` **多厂商 LLM 配置说明:** 系统支持 5 个 LLM 厂商,至少配置一个即可使用: | 厂商 | 环境变量 | 获取地址 | |------|----------|----------| | 百炼(阿里云) | `DASHSCOPE_API_KEY` | https://dashscope.console.aliyun.com/ | | OpenAI | `OPENAI_API_KEY` | https://platform.openai.com/api-keys | | DeepSeek | `DEEPSEEK_API_KEY` | https://platform.deepseek.com/ | | Kimi(Moonshot) | `MOONSHOT_API_KEY` | https://platform.moonshot.cn/ | | 智谱 | `ZHIPU_API_KEY` | https://open.bigmodel.cn/ | **示例配置(推荐百炼):** ```bash # 百炼(阿里云)- 推荐,国内访问快 DASHSCOPE_API_KEY=sk-your-dashscope-key DASHSCOPE_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 BAILIAN_MODELS=qwen-plus,qwen-max,qwen-turbo # 可选:其他厂商 OPENAI_API_KEY=sk-your-openai-key DEEPSEEK_API_KEY=sk-your-deepseek-key ``` ### 4. 启动基础服务(PostgreSQL、Redis、Milvus) ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml up -d postgres redis milvus-etcd milvus-minio milvus-standalone ``` ### 5. 初始化数据库 ```bash cd FinnewsHunter/backend python init_db.py ``` ### 5.1 初始化股票数据(可选,用于股票搜索功能) ```bash cd FinnewsHunter/backend python -m app.scripts.init_stocks # 将从 akshare 获取全部 A 股数据(约 5000+ 只)并存入数据库 ``` ### 6. 启动后端API服务 ```bash cd FinnewsHunter/backend uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 ``` ### 7. 启动Celery Worker和Beat(自动爬取) ```bash # 新开一个终端 cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml up -d celery-worker celery-beat ``` ### 8. 启动前端服务 ```bash # 新开一个终端 cd FinnewsHunter/frontend npm install # 首次需要安装依赖 npm run dev ``` ### 9. 
访问应用 - **前端界面**: http://localhost:3000 - **后端 API**: http://localhost:8000 - **API 文档**: http://localhost:8000/docs --- ## 🔄 服务管理 ### 查看所有服务状态 ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml ps ``` ### 重启所有服务 ```bash cd FinnewsHunter # 重启Docker服务(基础设施 + Celery) docker compose -f deploy/docker-compose.dev.yml restart # 如果后端API是独立启动的,需要手动重启 # Ctrl+C 停止后端进程,然后重新运行: cd backend uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 ``` ### 重启特定服务 ```bash cd FinnewsHunter # 只重启Celery(应用代码更改后) docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat # 只重启数据库 docker compose -f deploy/docker-compose.dev.yml restart postgres # 只重启Redis docker compose -f deploy/docker-compose.dev.yml restart redis ``` ### 停止所有服务 ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml down ``` ### 查看日志 ```bash cd FinnewsHunter # 查看Celery Worker日志 docker compose -f deploy/docker-compose.dev.yml logs -f celery-worker # 查看Celery Beat日志(定时任务调度) docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat # 查看PostgreSQL日志 docker compose -f deploy/docker-compose.dev.yml logs -f postgres # 查看所有服务日志 docker compose -f deploy/docker-compose.dev.yml logs -f ``` --- ## 🗑️ 重置数据库 ### 方式1:使用一键重置脚本(推荐)⭐ ```bash cd FinnewsHunter # 执行重置脚本 ./reset_all_data.sh # 输入 yes 确认 ``` **脚本会自动完成:** 1. ✅ 清空PostgreSQL中的所有新闻和任务数据 2. ✅ 清空Redis缓存 3. ✅ 重置数据库自增ID(从1重新开始) 4. ✅ 清空Celery调度文件 5. ✅ 自动重启Celery服务 **执行后等待:** - 5-10分钟系统会自动重新爬取数据 - 访问前端查看新数据 --- ### 方式2:手动重置(高级) #### 步骤1:清空PostgreSQL数据 ```bash # 进入PostgreSQL容器 docker exec -it finnews_postgres psql -U finnews -d finnews_db ``` 在PostgreSQL命令行中执行: ```sql -- 清空新闻表 DELETE FROM news; -- 清空任务表 DELETE FROM crawl_tasks; -- 清空分析表 DELETE FROM analyses; -- 重置自增ID ALTER SEQUENCE news_id_seq RESTART WITH 1; ALTER SEQUENCE crawl_tasks_id_seq RESTART WITH 1; ALTER SEQUENCE analyses_id_seq RESTART WITH 1; -- 验证结果(应该都是0) SELECT 'news表', COUNT(*) FROM news; SELECT 'crawl_tasks表', COUNT(*) FROM crawl_tasks; SELECT 'analyses表', COUNT(*) FROM analyses; -- 退出 \q ``` #### 步骤2:清空Redis缓存 ```bash cd FinnewsHunter docker exec finnews_redis redis-cli FLUSHDB ``` #### 步骤3:清空Celery调度文件 ```bash cd FinnewsHunter/backend rm -f celerybeat-schedule* ``` #### 步骤4:重启Celery服务 ```bash cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat ``` #### 步骤5:验证数据已清空 ```bash # 检查新闻数量(应该是0) docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT COUNT(*) FROM news;" # 检查Redis(应该是0或很小) docker exec finnews_redis redis-cli DBSIZE # 查看Celery是否开始爬取 docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat # 应该看到每分钟触发10个爬取任务 ``` --- ### 方式3:使用Python脚本重置 ```bash cd FinnewsHunter/backend python reset_database.py # 输入 yes 确认 ``` --- ### 方式4:快速手动清理(一行命令)🔥 **适用场景:** 当重置脚本不工作时,使用此方法最快速 ```bash cd FinnewsHunter # 步骤1:清空数据库表 docker exec finnews_postgres psql -U finnews -d finnews_db -c "DELETE FROM news; DELETE FROM crawl_tasks; DELETE FROM analyses;" # 步骤2:重置自增ID docker exec finnews_postgres psql -U finnews -d finnews_db -c "ALTER SEQUENCE news_id_seq RESTART WITH 1; ALTER SEQUENCE crawl_tasks_id_seq RESTART WITH 1; ALTER SEQUENCE analyses_id_seq RESTART WITH 1;" # 步骤3:清空Redis缓存 docker exec finnews_redis redis-cli FLUSHDB # 步骤4:清空Celery调度文件 rm -f backend/celerybeat-schedule* # 步骤5:重启Celery服务 docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat # 步骤6:验证是否清空(应该显示0) docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT COUNT(*) FROM news;" ``` **执行后立即刷新浏览器:** - Mac: `Command 
+ Shift + R` - Windows: `Ctrl + Shift + R` --- ### 🖥️ 清除前端缓存(重要!) **数据清空后,前端可能仍显示旧数据,这是因为浏览器缓存。** #### 方法1:硬刷新浏览器(推荐)⭐ **Mac系统:** ``` 按 Command + Shift + R 或 Command + Option + R ``` **Windows/Linux系统:** ``` 按 Ctrl + Shift + R 或 Ctrl + F5 ``` #### 方法2:开发者工具清空缓存 1. 按 `F12` 打开开发者工具 2. 右键点击刷新按钮(地址栏旁边) 3. 选择 **"清空缓存并硬性重新加载"** #### 方法3:清除浏览器缓存 1. **Chrome/Edge:** - `Command + Shift + Delete` (Mac) 或 `Ctrl + Shift + Delete` (Windows) - 勾选"缓存的图片和文件" - 时间范围选择"全部" - 点击"清除数据" 2. **刷新页面后,再次硬刷新** - 确保React Query缓存也被清除 #### 方法4:重启前端开发服务器(最彻底) ```bash # 在前端终端按 Ctrl+C 停止服务 # 然后重新启动 cd FinnewsHunter/frontend npm run dev ``` --- ## 📊 重置后的数据恢复时间线 | 时间 | 事件 | 预期结果 | |------|------|----------| | 0分钟 | 执行重置脚本 | 数据库清空,Redis清空 | | 1分钟 | Celery Beat开始调度 | 10个爬取任务被触发 | | 2-5分钟 | 第一批新闻保存 | 数据库开始有数据 | | 5-10分钟 | 所有源都有数据 | 前端可看到100+条新闻 | | 30分钟 | 数据持续增长 | 500+条新闻 | | 1小时 | 稳定运行 | 1000-2000条新闻 | **注意:** - 重置后需要等待5-10分钟才能看到新数据 - **前端必须硬刷新**(Command+Shift+R / Ctrl+Shift+R)清除缓存 - 不要频繁重置,会影响系统稳定性 **重置后立即硬刷新前端的步骤:** 1. 执行重置命令 2. **立即**在浏览器按 `Command + Shift + R` (Mac) 或 `Ctrl + Shift + R` (Windows) 3. 等待5-10分钟后再次刷新查看新数据 --- ## ⚠️ 爬虫状态检查 ### 查看哪些源正常工作 ```bash cd FinnewsHunter # 查看各源的新闻数量 docker exec finnews_postgres psql -U finnews -d finnews_db -c " SELECT source, COUNT(*) as count FROM news WHERE created_at > NOW() - INTERVAL '1 hour' GROUP BY source ORDER BY count DESC; " # 查看最近的爬取任务状态 docker exec finnews_postgres psql -U finnews -d finnews_db -c " SELECT source, crawled_count, saved_count, status, error_message FROM crawl_tasks WHERE created_at > NOW() - INTERVAL '10 minutes' ORDER BY created_at DESC LIMIT 20; " ``` ### 查看爬取错误 ```bash cd FinnewsHunter # 查看ERROR日志 docker compose -f deploy/docker-compose.dev.yml logs celery-worker | grep ERROR # 查看特定源的问题 docker compose -f deploy/docker-compose.dev.yml logs celery-worker | grep "jwview" ``` --- ## 📚 使用指南 ### 自动爬取模式(推荐)⭐ **系统已配置10个新闻源的自动爬取:** 1. 🌐 新浪财经 2. 🐧 腾讯财经 3. 💰 金融界 4. 📊 经济观察网 5. 📈 财经网 6. 📉 21经济网 7. 📰 每日经济新闻 8. 🎯 第一财经 9. 📧 网易财经 10. 💎 东方财富 **工作方式:** - ✅ Celery Beat 每1分钟自动触发所有源的爬取 - ✅ 自动去重(URL级别) - ✅ 智能时间筛选(保留24小时内新闻) - ✅ 股票关键词筛选 - ✅ 无需手动操作 **查看爬取进度:** ```bash # 查看Celery Beat调度日志 cd FinnewsHunter docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat # 查看Celery Worker执行日志 docker compose -f deploy/docker-compose.dev.yml logs -f celery-worker ``` --- ### 手动刷新(立即获取最新) **方式 1: 通过前端** 1. 访问 http://localhost:3000/news 2. 点击右上角"🔄 立即刷新"按钮 3. 系统会立即触发爬取,约2分钟后数据更新 **方式 2: 通过 API** ```bash # 强制刷新新浪财经 curl -X POST "http://localhost:8000/api/v1/news/refresh?source=sina" # 强制刷新所有源(需要逐个调用) for source in sina tencent jwview eeo caijing jingji21 nbd yicai 163 eastmoney; do curl -X POST "http://localhost:8000/api/v1/news/refresh?source=$source" sleep 1 done ``` --- ### 查看新闻列表 **方式 1: 通过前端(推荐)** - 访问 http://localhost:3000 - 首页:查看来源统计和最新新闻 - 新闻流:按来源和情感筛选新闻 - 支持批量选择:使用复选框选择多条新闻,支持 Shift 键范围选择 - 批量操作:全选/取消全选、批量删除、批量分析 **方式 2: 通过 API** ```bash # 获取所有来源的最新新闻(200条) curl "http://localhost:8000/api/v1/news/latest?limit=200" # 获取特定来源的新闻 curl "http://localhost:8000/api/v1/news/latest?source=sina&limit=50" # 按情感筛选(使用旧接口) curl "http://localhost:8000/api/v1/news/?sentiment=positive&limit=20" # 获取所有可用的新闻源列表 curl "http://localhost:8000/api/v1/news/sources" ``` --- ### 批量操作新闻 **前端操作:** 1. **批量选择**: - 点击新闻卡片左侧的复选框选择单条新闻 - 按住 Shift 键点击可进行范围选择 - 使用顶部工具栏的"全选"按钮选择当前筛选结果的所有新闻 - 切换新闻源或筛选条件时,选择状态会自动清空 2. **批量删除**: - 选择多条新闻后,点击顶部工具栏的"批量删除"按钮 - 确认删除对话框后,选中的新闻将被删除 - 删除后会自动刷新列表 3. 
**批量分析**: - 选择多条新闻后,点击顶部工具栏的"批量分析"按钮 - 系统会依次分析选中的新闻,显示进度和结果统计 - 分析完成后会显示成功/失败数量 **API 操作:** ```bash # 批量删除新闻 curl -X POST "http://localhost:8000/api/v1/news/batch/delete" \ -H "Content-Type: application/json" \ -d '{"news_ids": [1, 2, 3]}' # 批量分析新闻 curl -X POST "http://localhost:8000/api/v1/analysis/batch" \ -H "Content-Type: application/json" \ -d '{"news_ids": [1, 2, 3], "provider": "bailian", "model": "qwen-plus"}' ``` --- ### 分析新闻 **方式 1: 通过前端** - 在新闻卡片上点击"✨ 分析"按钮 - 等待3-5秒查看分析结果 - 点击新闻卡片打开详情抽屉,查看完整分析内容 **方式 2: 通过 API** ```bash # 分析指定ID的新闻(使用默认模型) curl -X POST http://localhost:8000/api/v1/analysis/news/1 # 分析新闻(指定模型) curl -X POST http://localhost:8000/api/v1/analysis/news/1 \ -H "Content-Type: application/json" \ -d '{"provider": "bailian", "model": "qwen-max"}' # 查看分析结果 curl http://localhost:8000/api/v1/analysis/1 ``` --- ### 切换 LLM 模型 **前端操作:** 1. 点击右上角的模型选择器(显示当前模型名称) 2. 在下拉菜单中选择不同的厂商和模型 3. 选择后自动保存,后续分析将使用新模型 **支持的模型:** - 🔥 **百炼**: qwen-plus, qwen-max, qwen-turbo, qwen-long - 🤖 **OpenAI**: gpt-4, gpt-4-turbo, gpt-3.5-turbo - 🧠 **DeepSeek**: deepseek-chat, deepseek-coder - 🌙 **Kimi**: moonshot-v1-8k, moonshot-v1-32k, moonshot-v1-128k - 🔮 **智谱**: glm-4, glm-4-plus, glm-4-air **API 获取可用模型列表:** ```bash curl http://localhost:8000/api/v1/llm/config ``` --- ### 搜索新闻 **前端操作:** 1. 在顶部搜索框输入关键词 2. 支持搜索:标题、内容、股票代码、来源 3. 匹配的关键词会高亮显示 4. 搜索带有 300ms 防抖,输入停止后自动搜索 **搜索示例:** - 搜索股票代码:`600519`(贵州茅台) - 搜索关键词:`新能源`、`半导体` - 搜索来源:`sina`、`eastmoney` --- ### 查看新闻详情 **前端操作:** 1. 点击任意新闻卡片 2. 右侧滑出详情抽屉,展示: - 📰 新闻标题和来源 - 📊 情感评分(利好/利空/中性) - 📈 关联股票代码 - 📝 完整新闻内容 - 🤖 AI 分析结果(Markdown 格式) - 🔗 原文链接 3. 点击"复制分析内容"可复制 Markdown 格式的分析报告 --- ### 股票 K 线分析 **前端操作:** 1. 访问 http://localhost:3000/stocks/SH600519(贵州茅台示例) 2. 使用右上角搜索框输入股票代码或名称(如 `茅台`、`600519`) 3. 选择时间周期:日K、60分、30分、15分、5分、1分 4. 图表支持: - 📈 K 线蜡烛图(OHLC) - 📊 成交量柱状图 - 📉 MA 均线(5/10/30/60日) **API 操作:** ```bash # 获取 K 线数据(日线,默认180条) curl "http://localhost:8000/api/v1/stocks/SH600519/kline?period=daily&limit=180" # 获取分钟 K 线(60分钟线) curl "http://localhost:8000/api/v1/stocks/SH600519/kline?period=60m&limit=200" # 搜索股票 curl "http://localhost:8000/api/v1/stocks/search/realtime?q=茅台&limit=10" # 查看数据库中的股票数量 curl "http://localhost:8000/api/v1/stocks/count" ``` --- ### 按来源筛选查看 **前端操作:** 1. **首页(Dashboard)** - 查看"新闻来源统计"卡片 - 点击任意来源按钮筛选 - 显示该来源的新闻数量和列表 2. **新闻流页面** - 顶部有10个来源筛选按钮 - 点击切换查看不同来源 - 支持来源+情感双重筛选 **API操作:** ```bash # 查看新浪财经的新闻 curl "http://localhost:8000/api/v1/news/latest?source=sina&limit=50" # 查看每日经济新闻 curl "http://localhost:8000/api/v1/news/latest?source=nbd&limit=50" # 查看所有来源 curl "http://localhost:8000/api/v1/news/latest?limit=200" ``` --- ## 🏗️ 项目结构 ``` FinnewsHunter/ ├── backend/ # 后端服务 │ ├── app/ │ │ ├── agents/ # 智能体定义(NewsAnalyst、辩论智能体等) │ │ ├── api/v1/ # FastAPI 路由 │ │ │ ├── analysis.py # 分析 API(支持批量分析) │ │ │ ├── llm_config.py # LLM 配置 API │ │ │ ├── news_v2.py # 新闻 API(支持批量删除) │ │ │ └── ... 
│ │ ├── core/ # 核心配置(config, database, redis, neo4j) │ │ ├── models/ # SQLAlchemy 数据模型 │ │ ├── services/ # 业务服务 │ │ │ ├── llm_service.py # LLM 服务(支持多厂商) │ │ │ ├── analysis_service.py # 分析服务(异步向量化) │ │ │ ├── embedding_service.py # 向量化服务(基于 AgenticX BailianEmbeddingProvider) │ │ │ └── stock_data_service.py # 股票数据服务 │ │ ├── storage/ # 存储封装 │ │ │ └── vector_storage.py # Milvus 向量存储(基于 AgenticX MilvusStorage) │ │ ├── tasks/ # Celery 任务 │ │ └── tools/ # AgenticX 工具(Crawler, Cleaner) │ ├── tests/ # 测试和工具脚本 │ │ ├── check_milvus_data.py # 检查 Milvus 向量存储数据 │ │ ├── check_news_embedding_status.py # 检查新闻向量化状态 │ │ └── manual_vectorize.py # 手动向量化指定新闻 │ ├── env.example # 环境变量模板 │ └── requirements.txt # Python 依赖 ├── frontend/ # React 前端 │ └── src/ │ ├── components/ # 组件 │ │ ├── ModelSelector.tsx # LLM 模型选择器 │ │ ├── NewsDetailDrawer.tsx # 新闻详情抽屉 │ │ └── HighlightText.tsx # 关键词高亮 │ ├── context/ # React Context │ ├── hooks/ # 自定义 Hooks │ │ └── useDebounce.ts # 防抖 Hook │ ├── layout/ # 布局组件 │ └── pages/ # 页面组件 │ └── NewsListPage.tsx # 新闻列表页面(支持批量操作) ├── deploy/ # 部署配置 │ ├── docker-compose.dev.yml # Docker Compose 配置 │ ├── Dockerfile.celery # Celery 镜像构建文件 │ └── celery-entrypoint.sh # Celery 容器启动脚本 ├── conclusions/ # 模块摘要文档 │ ├── backend/ # 后端模块总结 │ └── frontend/ # 前端模块总结 └── .dev-docs/ # 开发文档 ``` --- ## 🧪 测试与验收 ### MVP 验收标准 - [x] 新闻爬取成功并存入 PostgreSQL - [x] NewsAnalyst 调用 LLM 完成分析 - [x] 分析结果包含情感评分 - [x] 前端能够展示新闻和分析结果 - [x] 支持多厂商 LLM 动态切换 - [x] 新闻详情展示完整分析内容 - [x] 实时搜索和筛选功能 - [x] 批量选择、批量删除、批量分析功能 - [x] 基于 AgenticX 的向量化和存储服务 - [x] 异步向量化,不阻塞分析流程 ### 测试流程 1. **启动所有服务** ```bash ./start.sh ``` 2. **检查 Docker 容器状态** ```bash docker ps # 应看到: postgres, redis, milvus-standalone, milvus-etcd, milvus-minio ``` 3. **测试新闻爬取** ```bash curl -X POST http://localhost:8000/api/v1/news/crawl \ -H "Content-Type: application/json" \ -d '{"source": "sina", "start_page": 1, "end_page": 1}' # 等待 5-10 秒后查看结果 curl http://localhost:8000/api/v1/news/?limit=5 ``` 4. **测试智能体分析** ```bash # 获取第一条新闻的ID NEWS_ID=$(curl -s http://localhost:8000/api/v1/news/?limit=1 | jq '.[0].id') # 触发分析 curl -X POST http://localhost:8000/api/v1/analysis/news/$NEWS_ID # 查看分析结果 curl http://localhost:8000/api/v1/analysis/1 ``` 5. **测试前端界面** - 打开 `frontend/index.html` - 点击"爬取新闻"并等待完成 - 选择一条新闻点击"分析" - 查看情感评分是否显示 --- ## 🔧 故障排查 ### 问题 1: 数据库连接失败 **症状:** 后端启动报错 `could not connect to database` **解决方法:** ```bash cd FinnewsHunter # 检查 PostgreSQL 是否启动 docker ps | grep postgres # 查看日志 docker compose -f deploy/docker-compose.dev.yml logs postgres # 重启容器 docker compose -f deploy/docker-compose.dev.yml restart postgres # 等待30秒后重试后端启动 ``` --- ### 问题 2: Celery任务不执行 **症状:** 前端显示新闻数量为0,没有自动爬取 **排查步骤:** ```bash cd FinnewsHunter # 1. 检查Celery Worker是否运行 docker ps | grep celery # 2. 查看Celery Beat日志(应该看到每分钟触发任务) docker compose -f deploy/docker-compose.dev.yml logs celery-beat --tail=100 # 3. 查看Celery Worker日志(查看任务执行情况) docker compose -f deploy/docker-compose.dev.yml logs celery-worker --tail=100 # 4. 检查Redis连接 docker exec finnews_redis redis-cli PING # 应该返回 PONG # 5. 重启Celery服务 docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat ``` --- ### 问题 3: 爬取失败(404错误) **症状:** Celery日志显示 `404 Client Error: Not Found` **原因:** 新闻网站URL已变更 **解决方法:** ```bash # 1. 手动访问URL验证是否可用 curl -I https://finance.caijing.com.cn/ # 2. 如果URL变更,更新对应爬虫的配置 # 编辑 backend/app/tools/{source}_crawler.py # 更新 BASE_URL 和 STOCK_URL # 3. 清理Python缓存 cd FinnewsHunter/backend find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true # 4. 重启Celery cd .. 
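# (cd .. returns to the FinnewsHunter project root, so the deploy/docker-compose.dev.yml path below resolves)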
docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat ``` --- ### 问题 4: 只有新浪财经有数据 **症状:** 其他9个来源没有新闻 **可能原因:** 1. Celery Beat配置不完整 2. 爬虫代码有错误 3. 网站URL不正确 **解决方法:** ```bash cd FinnewsHunter # 1. 检查Celery Beat配置 docker compose -f deploy/docker-compose.dev.yml logs celery-beat | grep "crawl-" # 应该看到10个定时任务(crawl-sina, crawl-tencent, ..., crawl-eastmoney) # 2. 手动测试单个源的爬取 docker exec -it finnews_celery_worker python -c " from app.tools import get_crawler_tool crawler = get_crawler_tool('nbd') # 测试每日经济新闻 news = crawler.crawl() print(f'爬取到 {len(news)} 条新闻') " # 3. 查看数据库中各源的数据量 docker exec finnews_postgres psql -U finnews -d finnews_db -c " SELECT source, COUNT(*) as count FROM news GROUP BY source ORDER BY count DESC; " # 4. 如果某个源一直失败,查看详细错误 docker compose -f deploy/docker-compose.dev.yml logs celery-worker | grep "ERROR" ``` --- ### 问题 5: LLM 调用失败 **症状:** 分析功能不工作,报错 `LLM Provider NOT provided` **解决方法:** ```bash cd FinnewsHunter/backend # 1. 检查 API Key 是否配置 grep -E "DASHSCOPE_API_KEY|OPENAI_API_KEY|DEEPSEEK_API_KEY" .env # 2. 检查 Base URL 是否正确(百炼必须配置) grep DASHSCOPE_BASE_URL .env # 应该是: https://dashscope.aliyuncs.com/compatible-mode/v1 # 3. 验证 LLM 配置 API 是否正常 curl http://localhost:8000/api/v1/llm/config | jq '.providers[].has_api_key' # 至少有一个返回 true # 4. 如果使用百炼,确保配置完整 cat >> .env << EOF DASHSCOPE_API_KEY=sk-your-key DASHSCOPE_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 BAILIAN_MODELS=qwen-plus,qwen-max EOF # 5. 重启后端服务 ``` --- ### 问题 6: 前端显示空白或CORS错误 **症状:** 前端无法加载数据,浏览器Console显示CORS错误 **解决方法:** ```bash # 1. 检查后端CORS配置 cd FinnewsHunter/backend grep BACKEND_CORS_ORIGINS .env # 应该包含 http://localhost:3000 # 2. 检查前端API地址配置 cd ../frontend cat .env # VITE_API_URL 应该是 http://localhost:8000 # 3. 硬刷新浏览器 # Chrome/Edge: Ctrl+Shift+R (Windows) 或 Cmd+Shift+R (Mac) # 4. 
重启前端开发服务器 npm run dev ``` --- ### 问题 7: Milvus 连接失败 **症状:** 向量搜索功能不工作 **解决方法:** ```bash cd FinnewsHunter # Milvus 需要较长启动时间(约 60 秒) docker compose -f deploy/docker-compose.dev.yml logs milvus-standalone # 检查健康状态 docker inspect finnews_milvus | grep -A 10 Health # 重启Milvus相关服务 docker compose -f deploy/docker-compose.dev.yml restart milvus-etcd milvus-minio milvus-standalone ``` --- ### 问题 8: 数据统计不准确 **症状:** 首页显示的新闻数和实际不符 **解决方法:** ```bash # 使用重置脚本清空数据重新开始 cd FinnewsHunter ./reset_all_data.sh ``` --- ### 常用调试命令 ```bash cd FinnewsHunter # 查看所有容器状态 docker compose -f deploy/docker-compose.dev.yml ps # 查看某个服务的完整日志 docker compose -f deploy/docker-compose.dev.yml logs celery-worker --tail=500 # 进入容器调试 docker exec -it finnews_celery_worker bash # 查看数据库连接 docker exec finnews_postgres psql -U finnews -d finnews_db -c "\conninfo" # 查看Redis连接 docker exec finnews_redis redis-cli INFO # 测试网络连通性 docker exec finnews_celery_worker ping -c 3 postgres ``` --- ## ⚡ 快速参考(常用命令) ### 项目目录 ```bash cd FinnewsHunter ``` ### 一键操作 ```bash # 启动所有服务 docker compose -f deploy/docker-compose.dev.yml up -d # 停止所有服务 docker compose -f deploy/docker-compose.dev.yml down # 重启Celery(代码更新后) docker compose -f deploy/docker-compose.dev.yml restart celery-worker celery-beat # 清空所有数据重新开始 ./reset_all_data.sh ``` ### 查看状态 ```bash # 服务状态 docker compose -f deploy/docker-compose.dev.yml ps # 新闻数量 docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT source, COUNT(*) FROM news GROUP BY source;" # 任务数量 docker exec finnews_postgres psql -U finnews -d finnews_db -c "SELECT status, COUNT(*) FROM crawl_tasks GROUP BY status;" # Redis缓存 docker exec finnews_redis redis-cli DBSIZE ``` ### 查看日志 ```bash # Celery Beat(定时调度) docker compose -f deploy/docker-compose.dev.yml logs -f celery-beat # Celery Worker(任务执行) docker compose -f deploy/docker-compose.dev.yml logs -f celery-worker # PostgreSQL docker compose -f deploy/docker-compose.dev.yml logs -f postgres # 所有服务 docker compose -f deploy/docker-compose.dev.yml logs -f ``` ### 直接访问 - **前端**: http://localhost:3000 - **后端API**: http://localhost:8000 - **API文档**: http://localhost:8000/docs --- ## 📊 数据库结构 ### News(新闻表) - id, title, content, url, source - publish_time, stock_codes - sentiment_score, is_embedded ### Analysis(分析表) - id, news_id, agent_name - sentiment, sentiment_score, confidence - analysis_result, structured_data ### Stock(股票表) - id, code, name, industry, market --- ## 🛠️ 开发指南 ### 添加新的爬虫 1. 继承 `BaseCrawler` 类 2. 实现 `crawl()` 方法 3. 注册到 `tools/__init__.py` 示例: ```python # backend/app/tools/custom_crawler.py from .crawler_base import BaseCrawler class CustomCrawlerTool(BaseCrawler): name = "custom_crawler" def crawl(self, start_page, end_page): # 实现爬取逻辑 pass ``` ### 使用增强版爬虫(可选) 对于需要 JS 渲染或智能内容提取的场景,可使用增强版爬虫: ```python from app.tools.crawler_enhanced import crawl_url, EnhancedCrawler # 快速爬取单个 URL article = crawl_url("https://finance.sina.com.cn/xxx", engine='auto') print(article.to_markdown()) # 获取 LLM 消息格式(多模态) llm_messages = article.to_llm_message() # 批量爬取(带缓存) crawler = EnhancedCrawler(use_cache=True) articles = crawler.crawl_batch(urls, delay=1.0) ``` **支持的引擎:** - `requests`: 基础 HTTP 请求(默认) - `playwright`: JS 渲染(需安装 `playwright install chromium`) - `jina`: Jina Reader API(需配置 `JINA_API_KEY`) - `auto`: 自动选择最佳引擎 **安装可选依赖:** ```bash pip install markdownify readabilipy playwright playwright install chromium # 可选,用于 JS 渲染 ``` --- ### 添加新的智能体 1. 继承 `Agent` 类 2. 定义 role、goal、backstory 3. 
实现业务方法 示例: ```python # backend/app/agents/risk_analyst.py from agenticx import Agent class RiskAnalystAgent(Agent): def __init__(self, llm_provider): super().__init__( name="RiskAnalyst", role="风险分析师", goal="评估投资风险", llm_provider=llm_provider ) ``` --- ### 使用 AgenticX 组件 FinnewsHunter 深度集成了 AgenticX 框架的核心组件,避免重复造轮子: #### 1. 向量化服务(Embedding) 系统使用 `agenticx.embeddings.BailianEmbeddingProvider` 作为核心向量化引擎: ```python from app.services.embedding_service import EmbeddingService # 同步接口(适用于同步上下文) embedding_service = EmbeddingService() vector = embedding_service.embed_text("文本内容") # 异步接口(推荐在异步上下文中使用) vector = await embedding_service.aembed_text("文本内容") # 批量处理(Provider 内部已实现批量优化) vectors = embedding_service.embed_batch(["文本1", "文本2", "文本3"]) ``` **特点**: - 支持 Redis 缓存,避免重复计算 - 自动处理文本长度限制(6000字符) - 支持同步和异步两种接口,避免事件循环冲突 #### 2. 向量存储(Milvus) 系统使用 `agenticx.storage.vectordb_storages.milvus.MilvusStorage` 作为向量数据库: ```python from app.storage.vector_storage import VectorStorage vector_storage = VectorStorage() # 存储单个向量 vector_storage.store_embedding( news_id=1, text="新闻内容", embedding=[0.1, 0.2, ...] ) # 批量存储 vector_storage.store_embeddings_batch([ {"news_id": 1, "text": "内容1", "embedding": [...]}, {"news_id": 2, "text": "内容2", "embedding": [...]} ]) # 相似度搜索 results = vector_storage.search_similar(query_vector=[...], top_k=10) # 获取统计信息(带查询计数回退机制) stats = vector_storage.get_stats() ``` **特点**: - 直接使用 AgenticX MilvusStorage,无需重复实现 - 提供兼容性接口,简化调用 - 当 `num_entities` 不准确时,通过实际查询获取真实数量 - 支持异步操作,避免阻塞 #### 3. 异步向量化最佳实践 在异步上下文中(如 FastAPI 路由),推荐使用异步接口: ```python from app.services.embedding_service import EmbeddingService from app.storage.vector_storage import VectorStorage async def analyze_news(news_id: int, text: str): embedding_service = EmbeddingService() vector_storage = VectorStorage() # 使用异步接口,避免事件循环冲突 embedding = await embedding_service.aembed_text(text) # 后台异步存储向量(不阻塞分析流程) asyncio.create_task( vector_storage.store_embedding(news_id, text, embedding) ) # 继续执行分析逻辑... 
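    # Note: this snippet assumes `import asyncio` at the top of the module.
    # asyncio.create_task() schedules the vector write as a background task,
    # so the analysis flow does not wait for the Milvus insert to finish.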
``` **注意事项**: - 在异步上下文中,使用 `aembed_text()` 而不是 `embed_text()` - 向量化操作在后台异步执行,不阻塞主流程 - Milvus 的 `flush()` 操作已优化,默认不执行(依赖自动刷新) --- ## 多智能体辩论架构 FinnewsHunter 的核心特色是 **多空辩论机制**,通过多个专业智能体的协作与对抗,深度挖掘个股的投资价值和风险。 ### 核心参与角色 | 智能体 | 角色定位 | 核心职责 | |--------|----------|----------| | **BullResearcher** | 看多研究员 | 挖掘增长潜力、核心利好、估值优势 | | **BearResearcher** | 看空研究员 | 识别下行风险、负面催化剂、反驳乐观预期 | | **SearchAnalyst** | 搜索分析师 | 动态获取数据(AkShare/BochaAI/浏览器搜索) | | **InvestmentManager** | 投资经理 | 主持辩论、评估论点质量、做出最终决策 | ### 辩论数据流架构 ```mermaid graph TD subgraph 辩论启动 Manager[投资经理] -->|开场陈述| Orchestrator[辩论编排器] end subgraph 多轮辩论 Orchestrator -->|第N轮| Bull[看多研究员] Bull -->|发言 + 数据请求| Orchestrator Orchestrator -->|触发搜索| Searcher[搜索分析师] Searcher -->|财务数据| AkShare[AkShare] Searcher -->|实时新闻| BochaAI[BochaAI] Searcher -->|网页搜索| Browser[浏览器引擎] AkShare --> Context[更新上下文] BochaAI --> Context Browser --> Context Context --> Orchestrator Orchestrator -->|第N轮| Bear[看空研究员] Bear -->|发言 + 数据请求| Orchestrator end subgraph 最终决策 Orchestrator -->|智能数据补充| Searcher Orchestrator -->|综合判断| Manager Manager -->|投资评级| Result[最终报告] end ``` ### 动态搜索机制 辩论过程中,智能体可以通过特定格式请求额外数据: ``` [SEARCH: "最近的毛利率数据" source:akshare] -- 从 AkShare 获取财务数据 [SEARCH: "行业竞争格局分析" source:bochaai] -- 从 BochaAI 搜索新闻 [SEARCH: "近期资金流向" source:akshare] -- 获取资金流向 [SEARCH: "竞品对比分析"] -- 自动选择最佳数据源 ``` **支持的数据源:** - **AkShare**: 财务指标、K线行情、资金流向、机构持仓 - **BochaAI**: 实时新闻搜索、分析师报告 - **浏览器搜索**: 百度资讯、搜狗、360等多引擎搜索 - **知识库**: 历史新闻和分析数据 --- ## 📈 路线图 ### Phase 1: MVP(已完成) ✅ - [x] 项目基础设施 - [x] 数据库模型 - [x] 爬虫工具重构(10个新闻源) - [x] LLM 服务集成 - [x] NewsAnalyst 智能体 - [x] FastAPI 路由 - [x] React + TypeScript 前端 ### Phase 1.5: 多厂商 LLM 支持(已完成) ✅ - [x] 支持 5 大 LLM 厂商(百炼、OpenAI、DeepSeek、Kimi、智谱) - [x] 前端动态模型切换 - [x] LLM 配置 API(`/api/v1/llm/config`) - [x] 新闻详情抽屉(完整内容 + AI 分析) - [x] 实时搜索功能(多维度 + 关键词高亮) - [x] Markdown 渲染(支持表格、代码块) - [x] 一键复制分析报告 ### Phase 1.6: 股票分析与增强爬虫(已完成) ✅ - [x] 股票 K 线图(集成 akshare + klinecharts) - [x] 多周期支持(日K/60分/30分/15分/5分/1分) - [x] 股票搜索(代码/名称模糊查询,预加载 5000+ A股) - [x] 增强版爬虫模块 - [x] 多引擎支持(Requests/Playwright/Jina) - [x] 智能内容提取(readabilipy + 启发式算法) - [x] 内容质量评估与自动重试 - [x] 缓存机制和统一 Article 模型 ### Phase 1.7: AgenticX 深度集成与批量操作(已完成) ✅ - [x] 迁移到 AgenticX BailianEmbeddingProvider(移除冗余批量处理逻辑) - [x] 迁移到 AgenticX MilvusStorage(简化存储封装,移除重复代码) - [x] 异步向量化接口(aembed_text/aembed_batch),避免事件循环冲突 - [x] 后台异步向量化,不阻塞分析流程 - [x] Milvus 统计信息优化(查询计数回退机制) - [x] 前端批量选择功能(复选框 + Shift 范围选择) - [x] 批量删除新闻功能 - [x] 批量分析新闻功能(带进度显示和结果统计) - [x] Docker Compose 优化(Celery 镜像构建,提升启动性能) ### Phase 2: 多智能体辩论(已完成) ✅ - [x] BullResearcher & BearResearcher 智能体 - [x] SearchAnalyst 搜索分析师(动态数据获取) - [x] InvestmentManager 投资经理决策 - [x] 辩论编排器(DebateOrchestrator) - [x] 动态搜索机制(辩论中按需获取数据) - [x] 三种辩论模式:并行分析、实时辩论、快速分析 - [ ] 实时 WebSocket 推送(进行中) - [ ] 智能体执行轨迹可视化(进行中) ### Phase 3: 知识增强(计划中) - [ ] 金融知识图谱(Neo4j) - [ ] 智能体记忆系统 - [ ] GraphRetriever 图检索 ### Phase 4: 自我进化(计划中) - [ ] ACE 框架集成 - [ ] 投资策略 Playbook - [ ] 决策效果评估与学习 --- ## 📄 许可证 本项目遵循 AgenticX 的许可证。 --- ## 🙏 致谢 - [AgenticX](https://github.com/yourusername/AgenticX) - 多智能体框架 - [FastAPI](https://fastapi.tiangolo.com/) - Web 框架 - [Milvus](https://milvus.io/) - 向量数据库 - [阿里云百炼](https://dashscope.console.aliyun.com/) - LLM 服务 - [Shadcn UI](https://ui.shadcn.com/) - 前端组件库 --- ## ⭐ Star History 如果你觉得这个项目对你有帮助,欢迎给个 Star ⭐️! 
[![Star History Chart](https://api.star-history.com/svg?repos=DemonDamon/FinnewsHunter&type=Date)](https://star-history.com/#DemonDamon/FinnewsHunter&Date) --- **Built with ❤️ using AgenticX** ================================================ FILE: backend/.gitignore ================================================ # Python __pycache__/ *.py[cod] *$py.class *.so .Python env/ venv/ ENV/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # Environment variables .env .env.local # IDE .vscode/ .idea/ *.swp *.swo *~ # Logs logs/ *.log # Database *.db *.sqlite # OS .DS_Store Thumbs.db # Testing .pytest_cache/ .coverage htmlcov/ celerybeat-schedule celerybeat-schedule celerybeat-schedule ================================================ FILE: backend/README.md ================================================ # FinnewsHunter Backend Backend service for the financial news intelligent analysis system based on the AgenticX framework. ## Documentation Navigation ### Quick Start - **[QUICKSTART.md](../QUICKSTART.md)** - Quick start guide (recommended for beginners) ### Configuration Guides - **[CONFIG_GUIDE.md](CONFIG_GUIDE.md)** - **Unified Configuration Guide** (recommended) - Single configuration file supports all LLM providers - Quick switching between OpenAI / Bailian / Proxy - Includes scenario examples and working principles - **[env.example](env.example)** - Configuration template (with comments for all scenarios) ### Specialized Configuration - **[BAILIAN_SETUP.md](BAILIAN_SETUP.md)** - Detailed Alibaba Cloud Bailian configuration (recommended for Chinese users) - **[API_PROXY_GUIDE.md](API_PROXY_GUIDE.md)** - API proxy configuration guide --- ## Quick Configuration ### Method 1: Interactive Script (Recommended) ```bash chmod +x setup_env.sh ./setup_env.sh # Follow the prompts to select: # 1) OpenAI Official # 2) Alibaba Cloud Bailian (recommended for Chinese users) # 3) Other Proxy # 4) Manual Configuration ``` ### Method 2: Manual Configuration ```bash cp env.example .env nano .env # Choose configuration scheme according to comments ``` --- ## Main Features - **Multi-Agent System**: Based on AgenticX framework - NewsAnalyst: News analysis agent - More agents under development... - **Data Collection**: - Sina Finance crawler - JRJ Finance crawler - **Storage System**: - PostgreSQL: Relational data storage - Milvus: Vector database - Redis: Cache and task queue - **LLM Support**: - OpenAI (GPT-3.5/GPT-4) - Alibaba Cloud Bailian (Qwen) - Other OpenAI-compatible services --- ## Project Structure ``` backend/ ├── app/ │ ├── agents/ # Agent definitions │ ├── api/ # FastAPI routes │ ├── core/ # Core configuration │ ├── models/ # Data models │ ├── services/ # Business services │ ├── storage/ # Storage wrappers │ └── tools/ # Crawlers and tools ├── logs/ # Log files ├── tests/ # Test files ├── .env # Environment configuration (copy from env.example) ├── env.example # Configuration template ├── requirements.txt # Python dependencies └── start.sh # Startup script ``` --- ## Development Guide ### Start Development Environment ```bash # 1. Configure environment variables ./setup_env.sh # 2. 
Start services (including Docker containers) ./start.sh ``` ### Utility Scripts The project provides some utility scripts located in the `tests/` directory: ```bash # Check Milvus vector storage data python tests/check_milvus_data.py # Check news embedding status python tests/check_news_embedding_status.py # Manually vectorize a specific news item (for fixing unvectorized news) python tests/manual_vectorize.py ``` ### View Logs ```bash tail -f logs/finnews.log ``` --- ## Common Configuration Scenarios ### OpenAI Official ```bash LLM_MODEL=gpt-3.5-turbo OPENAI_API_KEY=sk-openai-key MILVUS_DIM=1536 ``` ### Alibaba Cloud Bailian (Recommended for Chinese Users) ```bash LLM_MODEL=qwen-plus OPENAI_API_KEY=sk-bailian-key OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 MILVUS_DIM=1024 ``` ### OpenAI Proxy ```bash LLM_MODEL=gpt-3.5-turbo OPENAI_API_KEY=sk-proxy-key OPENAI_BASE_URL=https://your-proxy.com/v1 MILVUS_DIM=1536 ``` For detailed information, see **[CONFIG_GUIDE.md](CONFIG_GUIDE.md)** --- ## API Documentation - Swagger UI: http://localhost:8000/docs - ReDoc: http://localhost:8000/redoc ### Troubleshooting If the documentation page appears blank or keeps loading: 1. **Check Browser Console**: Press F12 to open developer tools, check Console and Network tabs for errors 2. **Try ReDoc**: If Swagger UI fails to load, try accessing ReDoc (uses a different CDN) 3. **Clear Browser Cache**: Press `Ctrl+Shift+R` (Windows/Linux) or `Cmd+Shift+R` (Mac) to force refresh 4. **Check Network Connection**: Documentation pages need to load JavaScript resources from CDN, ensure network connection is normal 5. **Check Backend Service**: Ensure the backend service is running, verify by accessing http://localhost:8000/health ================================================ FILE: backend/README_zn.md ================================================ # FinnewsHunter Backend 基于 AgenticX 框架的金融新闻智能分析系统后端服务。 ## 文档导航 ### 快速开始 - **[QUICKSTART.md](../QUICKSTART.md)** - 快速启动指南(推荐新手阅读) ### 配置指南 - **[CONFIG_GUIDE.md](CONFIG_GUIDE.md)** - **统一配置指南**(推荐首选) - 一个配置文件支持所有 LLM 服务商 - 快速切换 OpenAI / 百炼 / 代理 - 包含场景示例和工作原理 - **[env.example](env.example)** - 配置模板(包含所有场景的注释) ### 专项配置 - **[BAILIAN_SETUP.md](BAILIAN_SETUP.md)** - 阿里云百炼详细配置(国内用户推荐) - **[API_PROXY_GUIDE.md](API_PROXY_GUIDE.md)** - API 代理配置详解 --- ## 快速配置 ### 方法 1: 交互式脚本(推荐) ```bash chmod +x setup_env.sh ./setup_env.sh # 按提示选择: # 1) OpenAI 官方 # 2) 阿里云百炼(推荐国内用户) # 3) 其他代理 # 4) 手动配置 ``` ### 方法 2: 手动配置 ```bash cp env.example .env nano .env # 根据注释选择配置方案 ``` --- ## 主要功能 - **多智能体系统**:基于 AgenticX 框架 - NewsAnalyst:新闻分析智能体 - 更多智能体开发中... - **数据采集**: - 新浪财经爬虫 - 金融界爬虫 - **存储系统**: - PostgreSQL:关系数据存储 - Milvus:向量数据库 - Redis:缓存和任务队列 - **LLM 支持**: - OpenAI (GPT-3.5/GPT-4) - 阿里云百炼(通义千问) - 其他 OpenAI 兼容服务 --- ## 项目结构 ``` backend/ ├── app/ │ ├── agents/ # 智能体定义 │ ├── api/ # FastAPI 路由 │ ├── core/ # 核心配置 │ ├── models/ # 数据模型 │ ├── services/ # 业务服务 │ ├── storage/ # 存储封装 │ └── tools/ # 爬虫和工具 ├── logs/ # 日志文件 ├── tests/ # 测试文件 ├── .env # 环境配置(从 env.example 复制) ├── env.example # 配置模板 ├── requirements.txt # Python 依赖 └── start.sh # 启动脚本 ``` --- ## 开发指南 ### 启动开发环境 ```bash # 1. 配置环境变量 ./setup_env.sh # 2. 
启动服务(包括 Docker 容器) ./start.sh ``` ### 工具脚本 项目提供了一些实用工具脚本,位于 `tests/` 目录下: ```bash # 检查 Milvus 向量存储数据 python tests/check_milvus_data.py # 检查新闻向量化状态 python tests/check_news_embedding_status.py # 手动向量化指定新闻(用于修复未向量化的新闻) python tests/manual_vectorize.py ``` ### 查看日志 ```bash tail -f logs/finnews.log ``` --- ## 常用配置场景 ### OpenAI 官方 ```bash LLM_MODEL=gpt-3.5-turbo OPENAI_API_KEY=sk-openai-key MILVUS_DIM=1536 ``` ### 阿里云百炼(推荐国内) ```bash LLM_MODEL=qwen-plus OPENAI_API_KEY=sk-bailian-key OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 MILVUS_DIM=1024 ``` ### OpenAI 代理 ```bash LLM_MODEL=gpt-3.5-turbo OPENAI_API_KEY=sk-proxy-key OPENAI_BASE_URL=https://your-proxy.com/v1 MILVUS_DIM=1536 ``` 详细说明见 **[CONFIG_GUIDE.md](CONFIG_GUIDE.md)** --- ## API 文档 - Swagger UI: http://localhost:8000/docs - ReDoc: http://localhost:8000/redoc ### 手动触发爬取 如果某个新闻源显示为空,可以手动触发实时爬取: ```bash # 触发腾讯财经爬取 curl -X POST "http://localhost:8000/api/v1/tasks/realtime" \ -H "Content-Type: application/json" \ -d '{"source": "tencent", "force_refresh": true}' # 触发经济观察网爬取 curl -X POST "http://localhost:8000/api/v1/tasks/realtime" \ -H "Content-Type: application/json" \ -d '{"source": "eeo", "force_refresh": true}' ``` 支持的新闻源: - `sina` - 新浪财经 - `tencent` - 腾讯财经 - `eeo` - 经济观察网 - `jwview` - 金融界 - `caijing` - 财经网 - `jingji21` - 21经济网 - `nbd` - 每日经济新闻 - `yicai` - 第一财经 - `163` - 网易财经 - `eastmoney` - 东方财富 ### 故障排查 如果文档页面显示空白或一直加载: 1. **检查浏览器控制台**:按 F12 打开开发者工具,查看 Console 和 Network 标签页是否有错误 2. **尝试 ReDoc**:如果 Swagger UI 无法加载,尝试访问 ReDoc(使用不同的 CDN) 3. **清除浏览器缓存**:按 `Ctrl+Shift+R` (Windows/Linux) 或 `Cmd+Shift+R` (Mac) 强制刷新 4. **检查网络连接**:文档页面需要从 CDN 加载 JavaScript 资源,确保网络连接正常 5. **检查后端服务**:确保后端服务正在运行,可以访问 http://localhost:8000/health 验证 ================================================ FILE: backend/add_raw_html_column.py ================================================ """ 数据库迁移:添加 raw_html 字段 """ import os from pathlib import Path from dotenv import load_dotenv # 加载环境变量 env_path = Path(__file__).parent / ".env" load_dotenv(env_path) # 构建数据库 URL POSTGRES_USER = os.getenv("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "postgres") POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost") POSTGRES_PORT = os.getenv("POSTGRES_PORT", "5432") POSTGRES_DB = os.getenv("POSTGRES_DB", "finnews_db") DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" from sqlalchemy import create_engine, text def add_raw_html_column(): """添加 raw_html 字段到 news 表""" print("🔧 正在添加 raw_html 字段...") engine = create_engine(DATABASE_URL) with engine.connect() as conn: # 检查字段是否已存在 result = conn.execute(text(""" SELECT column_name FROM information_schema.columns WHERE table_name = 'news' AND column_name = 'raw_html' """)) if result.fetchone(): print("✅ raw_html 字段已存在,无需迁移") return # 添加字段 conn.execute(text(""" ALTER TABLE news ADD COLUMN raw_html TEXT """)) conn.commit() print("✅ raw_html 字段已添加成功!") if __name__ == "__main__": print("=" * 50) print("📦 数据库迁移:添加 raw_html 字段") print("=" * 50) add_raw_html_column() ================================================ FILE: backend/app/__init__.py ================================================ """ FinnewsHunter Backend Application """ __version__ = "0.1.0" ================================================ FILE: backend/app/agents/__init__.py ================================================ """ 智能体模块 """ from .news_analyst import NewsAnalystAgent, create_news_analyst from .debate_agents import ( BullResearcherAgent, BearResearcherAgent, 
InvestmentManagerAgent, DebateWorkflow, create_debate_workflow, ) from .data_collector_v2 import DataCollectorAgentV2, QuickAnalystAgent, create_data_collector from .orchestrator import DebateOrchestrator, create_orchestrator from .quantitative_agent import QuantitativeAgent, create_quantitative_agent __all__ = [ "NewsAnalystAgent", "create_news_analyst", "BullResearcherAgent", "BearResearcherAgent", "InvestmentManagerAgent", "DebateWorkflow", "create_debate_workflow", "DataCollectorAgentV2", "QuickAnalystAgent", "create_data_collector", "DebateOrchestrator", "create_orchestrator", "QuantitativeAgent", "create_quantitative_agent", ] ================================================ FILE: backend/app/agents/data_collector.py ================================================ """ 数据专员智能体 负责在辩论前搜集和整理相关数据资料,包括: - 新闻数据(从数据库或BochaAI搜索) - 财务数据(从AkShare获取) - 行情数据(实时行情、K线等) """ import logging from typing import Dict, Any, List, Optional from datetime import datetime from agenticx.core.agent import Agent from ..services.llm_service import get_llm_provider logger = logging.getLogger(__name__) class DataCollectorAgent(Agent): """数据专员智能体""" def __init__(self, llm_provider=None, organization_id: str = "finnews"): super().__init__( name="DataCollector", role="数据专员", goal="搜集和整理股票相关的新闻、财务和行情数据,为辩论提供全面的信息支持", backstory="""你是一位专业的金融数据分析师,擅长从多个数据源搜集和整理信息。 你的职责是在辩论开始前,为Bull/Bear研究员提供全面、准确、及时的数据支持。 你需要: 1. 搜集最新的相关新闻 2. 获取关键财务指标 3. 分析资金流向 4. 整理行情数据 你的工作质量直接影响辩论的深度和专业性。""", organization_id=organization_id ) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) logger.info(f"Initialized {self.name} agent") async def collect_data( self, stock_code: str, stock_name: str, data_requirements: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """ 搜集股票相关数据 Args: stock_code: 股票代码 stock_name: 股票名称 data_requirements: 数据需求配置 Returns: 包含各类数据的字典 """ logger.info(f"📊 DataCollector: 开始搜集 {stock_name}({stock_code}) 的数据...") result = { "stock_code": stock_code, "stock_name": stock_name, "collected_at": datetime.utcnow().isoformat(), "news": [], "financial": {}, "fund_flow": {}, "realtime_quote": {}, "summary": "" } try: # 1. 搜集新闻数据 news_data = await self._collect_news(stock_code, stock_name) result["news"] = news_data logger.info(f"📰 DataCollector: 搜集到 {len(news_data)} 条新闻") # 2. 搜集财务数据 financial_data = await self._collect_financial(stock_code) result["financial"] = financial_data logger.info(f"💰 DataCollector: 搜集到财务数据") # 3. 搜集资金流向 fund_flow = await self._collect_fund_flow(stock_code) result["fund_flow"] = fund_flow logger.info(f"💸 DataCollector: 搜集到资金流向数据") # 4. 搜集实时行情 realtime = await self._collect_realtime_quote(stock_code) result["realtime_quote"] = realtime logger.info(f"📈 DataCollector: 搜集到实时行情") # 5. 
生成数据摘要 result["summary"] = await self._generate_summary(result) logger.info(f"📋 DataCollector: 数据摘要生成完成") except Exception as e: logger.error(f"DataCollector 搜集数据时出错: {e}", exc_info=True) result["error"] = str(e) return result async def _collect_news(self, stock_code: str, stock_name: str) -> List[Dict[str, Any]]: """搜集新闻数据""" from ..services.news_service import news_service try: # 从数据库获取已有新闻 news_list = await news_service.get_news_by_stock(stock_code, limit=20) return [ { "title": news.title, "content": news.content[:500] if news.content else "", "source": news.source, "published_at": news.published_at.isoformat() if news.published_at else None, "sentiment": news.sentiment } for news in news_list ] except Exception as e: logger.warning(f"从数据库获取新闻失败: {e}") return [] async def _collect_financial(self, stock_code: str) -> Dict[str, Any]: """搜集财务数据""" from ..services.stock_data_service import stock_data_service try: return await stock_data_service.get_financial_indicators(stock_code) or {} except Exception as e: logger.warning(f"获取财务数据失败: {e}") return {} async def _collect_fund_flow(self, stock_code: str) -> Dict[str, Any]: """搜集资金流向数据""" from ..services.stock_data_service import stock_data_service try: return await stock_data_service.get_fund_flow(stock_code) or {} except Exception as e: logger.warning(f"获取资金流向失败: {e}") return {} async def _collect_realtime_quote(self, stock_code: str) -> Dict[str, Any]: """搜集实时行情""" from ..services.stock_data_service import stock_data_service try: return await stock_data_service.get_realtime_quote(stock_code) or {} except Exception as e: logger.warning(f"获取实时行情失败: {e}") return {} async def _generate_summary(self, data: Dict[str, Any]) -> str: """使用LLM生成数据摘要""" try: # 准备摘要内容 news_summary = "" if data.get("news"): news_titles = [n["title"] for n in data["news"][:5]] news_summary = f"最新新闻({len(data['news'])}条):\n" + "\n".join(f"- {t}" for t in news_titles) financial_summary = "" if data.get("financial"): f = data["financial"] financial_summary = f"""财务指标: - PE: {f.get('pe', 'N/A')} - PB: {f.get('pb', 'N/A')} - ROE: {f.get('roe', 'N/A')} - 净利润增长率: {f.get('net_profit_growth', 'N/A')}""" fund_flow_summary = "" if data.get("fund_flow"): ff = data["fund_flow"] fund_flow_summary = f"""资金流向: - 主力净流入: {ff.get('main_net_inflow', 'N/A')} - 散户净流入: {ff.get('retail_net_inflow', 'N/A')}""" realtime_summary = "" if data.get("realtime_quote"): rt = data["realtime_quote"] realtime_summary = f"""实时行情: - 当前价: {rt.get('price', 'N/A')} - 涨跌幅: {rt.get('change_pct', 'N/A')}% - 成交量: {rt.get('volume', 'N/A')}""" summary = f"""## {data['stock_name']}({data['stock_code']}) 数据摘要 {realtime_summary} {financial_summary} {fund_flow_summary} {news_summary} 数据搜集时间: {data['collected_at']}""" return summary except Exception as e: logger.error(f"生成数据摘要失败: {e}") return f"数据搜集完成,但生成摘要时出错: {e}" async def analyze_data_quality(self, data: Dict[str, Any]) -> Dict[str, Any]: """分析数据质量和完整性""" quality = { "score": 0, "max_score": 100, "details": [], "recommendations": [] } # 检查新闻数据 news_count = len(data.get("news", [])) if news_count >= 10: quality["score"] += 30 quality["details"].append(f"✅ 新闻数据充足({news_count}条)") elif news_count >= 5: quality["score"] += 20 quality["details"].append(f"⚠️ 新闻数据较少({news_count}条)") quality["recommendations"].append("建议搜集更多新闻以支持分析") elif news_count > 0: quality["score"] += 10 quality["details"].append(f"⚠️ 新闻数据不足({news_count}条)") quality["recommendations"].append("新闻数据偏少,分析可能不够全面") else: quality["details"].append("❌ 无新闻数据") 
quality["recommendations"].append("缺少新闻数据,建议先进行定向爬取") # 检查财务数据 if data.get("financial"): quality["score"] += 25 quality["details"].append("✅ 财务数据完整") else: quality["details"].append("❌ 缺少财务数据") quality["recommendations"].append("无法获取财务指标") # 检查资金流向 if data.get("fund_flow"): quality["score"] += 20 quality["details"].append("✅ 资金流向数据完整") else: quality["details"].append("⚠️ 缺少资金流向数据") # 检查实时行情 if data.get("realtime_quote"): quality["score"] += 25 quality["details"].append("✅ 实时行情数据完整") else: quality["details"].append("⚠️ 缺少实时行情数据") return quality # 快速分析师(用于快速分析模式) class QuickAnalystAgent(Agent): """快速分析师智能体""" def __init__(self, llm_provider=None, organization_id: str = "finnews"): super().__init__( name="QuickAnalyst", role="快速分析师", goal="快速综合多角度给出投资建议", backstory="""你是一位经验丰富的量化分析师,擅长快速分析和决策。 你能够在短时间内综合考虑多空因素,给出简洁明了的投资建议。 你的分析风格是:快速、准确、实用。""", organization_id=organization_id ) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) logger.info(f"Initialized {self.name} agent") async def quick_analyze( self, stock_code: str, stock_name: str, context: str ) -> Dict[str, Any]: """快速分析""" # 获取当前系统时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") prompt = f"""请对 {stock_name}({stock_code}) 进行快速投资分析。 【当前时间】 {current_time} 背景资料: {context} 请在1分钟内给出: 1. 核心观点(一句话) 2. 看多因素(3点) 3. 看空因素(3点) 4. 投资建议(买入/持有/卖出) 5. 目标价位和止损价位 请用简洁的语言,直接给出结论。""" try: response = await self._llm_provider.chat(prompt) return { "success": True, "analysis": response, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Quick analysis failed: {e}") return { "success": False, "error": str(e) } ================================================ FILE: backend/app/agents/data_collector_v2.py ================================================ """ 数据专员智能体 V2 (DataCollectorAgent) 统一负责所有数据获取任务,支持: - 辩论前的初始数据收集 - 辩论中的动态数据补充 - 用户追问时的按需搜索 核心特性: 1. 计划/执行分离:先生成搜索计划,用户确认后再执行 2. 多数据源支持:AkShare、BochaAI、网页搜索、知识库 3. 
智能意图识别:根据用户问题自动选择数据源 """ import logging import re import asyncio from typing import Dict, Any, List, Optional, ClassVar, Pattern from datetime import datetime from enum import Enum from pydantic import BaseModel, Field from agenticx.core.agent import Agent from ..services.llm_service import get_llm_provider from ..services.stock_data_service import stock_data_service from ..tools.bochaai_search import bochaai_search, SearchResult from ..tools.interactive_crawler import InteractiveCrawler logger = logging.getLogger(__name__) class SearchSource(str, Enum): """搜索数据源类型""" AKSHARE = "akshare" # AkShare 财务/行情数据 BOCHAAI = "bochaai" # BochaAI Web搜索 BROWSER = "browser" # 交互式浏览器搜索 KNOWLEDGE_BASE = "kb" # 内部知识库 ALL = "all" # 所有来源 class SearchTask(BaseModel): """单个搜索任务""" id: str = Field(..., description="任务ID") source: SearchSource = Field(..., description="数据源") query: str = Field(..., description="搜索查询") description: str = Field("", description="任务描述(用于展示给用户)") data_type: Optional[str] = Field(None, description="数据类型(如 financial, news, kline)") icon: str = Field("🔍", description="图标(用于UI展示)") estimated_time: int = Field(3, description="预计耗时(秒)") class SearchPlan(BaseModel): """搜索计划""" plan_id: str = Field(..., description="计划ID") stock_code: str = Field(..., description="股票代码") stock_name: str = Field("", description="股票名称") user_query: str = Field(..., description="用户原始问题") tasks: List[SearchTask] = Field(default_factory=list, description="搜索任务列表") total_estimated_time: int = Field(0, description="总预计耗时(秒)") created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) status: str = Field("pending", description="状态:pending, confirmed, executing, completed, cancelled") class SearchResult(BaseModel): """搜索结果""" task_id: str source: str success: bool data: Dict[str, Any] = Field(default_factory=dict) summary: str = "" error: Optional[str] = None execution_time: float = 0 class DataCollectorAgentV2(Agent): """ 数据专员智能体 V2 支持"确认优先"模式: 1. 用户 @数据专员 提问 2. 生成搜索计划(不执行) 3. 用户确认后执行 4. 
返回结果 """ # 关键词到数据源的映射 KEYWORD_SOURCE_MAP: ClassVar[Dict[str, tuple]] = { # 财务相关 -> AkShare "财务": (SearchSource.AKSHARE, "financial", "📊"), "pe": (SearchSource.AKSHARE, "financial", "📊"), "pb": (SearchSource.AKSHARE, "financial", "📊"), "roe": (SearchSource.AKSHARE, "financial", "📊"), "利润": (SearchSource.AKSHARE, "financial", "📊"), "营收": (SearchSource.AKSHARE, "financial", "📊"), "估值": (SearchSource.AKSHARE, "financial", "📊"), "市盈": (SearchSource.AKSHARE, "financial", "📊"), "市净": (SearchSource.AKSHARE, "financial", "📊"), "报表": (SearchSource.AKSHARE, "financial", "📊"), # 资金/行情 -> AkShare "资金": (SearchSource.AKSHARE, "fund_flow", "💰"), "主力": (SearchSource.AKSHARE, "fund_flow", "💰"), "流入": (SearchSource.AKSHARE, "fund_flow", "💰"), "流出": (SearchSource.AKSHARE, "fund_flow", "💰"), "行情": (SearchSource.AKSHARE, "realtime", "📈"), "价格": (SearchSource.AKSHARE, "realtime", "📈"), "涨跌": (SearchSource.AKSHARE, "realtime", "📈"), "k线": (SearchSource.AKSHARE, "kline", "📈"), "走势": (SearchSource.AKSHARE, "kline", "📈"), # 新闻相关 -> BochaAI "新闻": (SearchSource.BOCHAAI, "news", "📰"), "资讯": (SearchSource.BOCHAAI, "news", "📰"), "报道": (SearchSource.BOCHAAI, "news", "📰"), "公告": (SearchSource.BOCHAAI, "news", "📰"), "消息": (SearchSource.BOCHAAI, "news", "📰"), # 上下游/产业链 -> 多源搜索 "上下游": (SearchSource.BROWSER, "industry", "🔗"), "供应链": (SearchSource.BROWSER, "industry", "🔗"), "客户": (SearchSource.BROWSER, "industry", "🔗"), "供应商": (SearchSource.BROWSER, "industry", "🔗"), "合作": (SearchSource.BROWSER, "industry", "🔗"), "产业链": (SearchSource.BROWSER, "industry", "🔗"), } def __init__(self, llm_provider=None, organization_id: str = "finnews"): super().__init__( name="DataCollector", role="数据专员", goal="根据用户需求,从多个数据源搜集和整理相关信息,支持辩论前准备和辩论中追问", backstory="""你是一位专业的金融数据专家,精通各类金融数据源的使用。 你的职责是: 1. 理解用户的数据需求 2. 制定合理的搜索计划 3. 从多个数据源获取数据 4. 整理并格式化数据 你能够访问的数据源包括: - AkShare: 股票财务指标、K线行情、资金流向等 - BochaAI: 实时新闻搜索、财经报道 - 网页搜索: 百度资讯、搜狗等 - 知识库: 历史新闻和分析数据""", organization_id=organization_id ) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) # 初始化搜索工具 self._interactive_crawler = InteractiveCrawler(timeout=20) logger.info(f"✅ Initialized DataCollectorV2 with multi-source search capabilities") async def generate_search_plan( self, query: str, stock_code: str, stock_name: str = "" ) -> SearchPlan: """ 生成搜索计划(不执行) 根据用户问题分析需要哪些数据,生成待确认的搜索计划 Args: query: 用户问题 stock_code: 股票代码 stock_name: 股票名称 Returns: SearchPlan 对象 """ logger.info(f"📋 DataCollector: 为 '{query}' 生成搜索计划...") plan_id = f"plan_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{stock_code}" plan = SearchPlan( plan_id=plan_id, stock_code=stock_code, stock_name=stock_name or stock_code, user_query=query, tasks=[], status="pending" ) query_lower = query.lower() # 1. 基于关键词匹配生成任务 matched_sources = set() for keyword, (source, data_type, icon) in self.KEYWORD_SOURCE_MAP.items(): if keyword in query_lower: if (source, data_type) not in matched_sources: matched_sources.add((source, data_type)) task = self._create_task( source=source, data_type=data_type, icon=icon, query=query, stock_code=stock_code, stock_name=stock_name ) plan.tasks.append(task) # 2. 如果没有匹配到任何关键词,使用 LLM 分析 if not plan.tasks: plan.tasks = await self._analyze_with_llm(query, stock_code, stock_name) # 3. 
如果还是没有任务,添加默认的综合搜索 if not plan.tasks: plan.tasks = [ SearchTask( id=f"task_{plan_id}_1", source=SearchSource.BOCHAAI, query=f"{stock_name or stock_code} {query}", description=f"搜索 {stock_name} 相关新闻", icon="📰", estimated_time=3 ), SearchTask( id=f"task_{plan_id}_2", source=SearchSource.AKSHARE, query=query, description="获取最新财务和行情数据", data_type="overview", icon="📊", estimated_time=2 ) ] # 计算总耗时 plan.total_estimated_time = sum(t.estimated_time for t in plan.tasks) logger.info(f"✅ 生成搜索计划: {len(plan.tasks)} 个任务,预计耗时 {plan.total_estimated_time}s") return plan def _create_task( self, source: SearchSource, data_type: str, icon: str, query: str, stock_code: str, stock_name: str ) -> SearchTask: """创建搜索任务""" task_id = f"task_{datetime.utcnow().strftime('%H%M%S%f')}" # 根据数据类型生成描述 descriptions = { "financial": f"获取 {stock_name or stock_code} 财务指标(PE/PB/ROE等)", "fund_flow": f"获取 {stock_name or stock_code} 资金流向(主力/散户)", "realtime": f"获取 {stock_name or stock_code} 实时行情", "kline": f"获取 {stock_name or stock_code} K线走势", "news": f"搜索 {stock_name or stock_code} 最新新闻", "industry": f"搜索 {stock_name or stock_code} 产业链/上下游信息", } # 根据数据类型生成查询 queries = { "financial": stock_code, "fund_flow": stock_code, "realtime": stock_code, "kline": stock_code, "news": f"{stock_name or stock_code} {query}", "industry": f"{stock_name or stock_code} {query}", } return SearchTask( id=task_id, source=source, query=queries.get(data_type, query), description=descriptions.get(data_type, f"搜索: {query}"), data_type=data_type, icon=icon, estimated_time=3 if source != SearchSource.BROWSER else 5 ) async def _analyze_with_llm( self, query: str, stock_code: str, stock_name: str ) -> List[SearchTask]: """使用 LLM 分析需要哪些数据""" try: prompt = f"""分析以下用户问题,判断需要搜索哪些数据: 用户问题: "{query}" 股票: {stock_name}({stock_code}) 可用数据源: 1. akshare - 财务数据(PE/PB/ROE等)、资金流向、实时行情、K线 2. bochaai - 新闻搜索、财经报道 3. browser - 网页搜索(适合搜索产业链、上下游、合作方等) 4. 
kb - 历史新闻数据库 请返回需要搜索的内容,格式如下(每行一个): SOURCE:数据源|TYPE:数据类型|QUERY:搜索词|DESC:描述 示例: SOURCE:bochaai|TYPE:news|QUERY:ST国华 上下游|DESC:搜索ST国华上下游相关新闻 SOURCE:akshare|TYPE:financial|QUERY:002074|DESC:获取国轩高科财务数据 只输出2-4个最相关的搜索任务。""" response = self._llm_provider.invoke([ {"role": "system", "content": "你是数据搜索专家,帮助分析需要哪些数据。"}, {"role": "user", "content": prompt} ]) content = response.content if hasattr(response, 'content') else str(response) tasks = [] for line in content.strip().split('\n'): if 'SOURCE:' in line: try: parts = {} for part in line.split('|'): if ':' in part: key, value = part.split(':', 1) parts[key.strip().upper()] = value.strip() if 'SOURCE' in parts: source_str = parts['SOURCE'].lower() source = SearchSource(source_str) if source_str in [s.value for s in SearchSource] else SearchSource.BOCHAAI tasks.append(SearchTask( id=f"task_llm_{len(tasks)+1}", source=source, query=parts.get('QUERY', query), description=parts.get('DESC', f"搜索: {query}"), data_type=parts.get('TYPE', 'general'), icon=self._get_icon_for_source(source), estimated_time=3 )) except Exception as e: logger.debug(f"解析 LLM 响应行失败: {e}") return tasks except Exception as e: logger.warning(f"LLM 分析失败: {e}") return [] def _get_icon_for_source(self, source: SearchSource) -> str: """获取数据源对应的图标""" icons = { SearchSource.AKSHARE: "📊", SearchSource.BOCHAAI: "📰", SearchSource.BROWSER: "🌐", SearchSource.KNOWLEDGE_BASE: "📚", SearchSource.ALL: "🔍" } return icons.get(source, "🔍") async def execute_search_plan( self, plan: SearchPlan ) -> Dict[str, Any]: """ 执行搜索计划 Args: plan: 已确认的搜索计划 Returns: 搜索结果汇总 """ logger.info(f"🚀 DataCollector: 开始执行搜索计划 {plan.plan_id}...") plan.status = "executing" start_time = datetime.utcnow() results = { "plan_id": plan.plan_id, "stock_code": plan.stock_code, "stock_name": plan.stock_name, "user_query": plan.user_query, "task_results": [], "combined_data": {}, "summary": "", "success": False, "execution_time": 0 } # 并行执行所有任务 async_tasks = [] for task in plan.tasks: async_tasks.append(self._execute_task(task, plan.stock_code, plan.stock_name)) task_results = await asyncio.gather(*async_tasks, return_exceptions=True) # 收集结果 for i, result in enumerate(task_results): if isinstance(result, Exception): logger.error(f"任务执行失败: {result}") results["task_results"].append(SearchResult( task_id=plan.tasks[i].id, source=plan.tasks[i].source.value, success=False, error=str(result) ).dict()) else: results["task_results"].append(result.dict() if hasattr(result, 'dict') else result) if result.get("success"): # 合并数据 source = result.get("source", "unknown") if source not in results["combined_data"]: results["combined_data"][source] = {} results["combined_data"][source].update(result.get("data", {})) # 生成综合摘要 results["summary"] = await self._generate_combined_summary( plan.user_query, results["combined_data"], plan.stock_name ) # 计算执行时间 end_time = datetime.utcnow() results["execution_time"] = (end_time - start_time).total_seconds() results["success"] = any(r.get("success") for r in results["task_results"]) plan.status = "completed" logger.info(f"✅ 搜索计划执行完成,耗时 {results['execution_time']:.1f}s") return results async def _execute_task( self, task: SearchTask, stock_code: str, stock_name: str ) -> Dict[str, Any]: """执行单个搜索任务""" logger.info(f"🔍 执行任务: {task.description}") start_time = datetime.utcnow() result = { "task_id": task.id, "source": task.source.value, "success": False, "data": {}, "summary": "", "execution_time": 0 } try: if task.source == SearchSource.AKSHARE: data = await self._search_akshare(task.query, stock_code, 
task.data_type) result["data"] = data or {} result["success"] = bool(data) elif task.source == SearchSource.BOCHAAI: data = await self._search_bochaai(task.query, stock_name) result["data"] = data or {} result["success"] = bool(data) elif task.source == SearchSource.BROWSER: data = await self._search_browser(task.query) result["data"] = data or {} result["success"] = bool(data) elif task.source == SearchSource.KNOWLEDGE_BASE: data = await self._search_knowledge_base(task.query, stock_code) result["data"] = data or {} result["success"] = bool(data) except Exception as e: logger.error(f"任务 {task.id} 执行失败: {e}") result["error"] = str(e) end_time = datetime.utcnow() result["execution_time"] = (end_time - start_time).total_seconds() return result async def _search_akshare( self, query: str, stock_code: str, data_type: Optional[str] = None ) -> Optional[Dict[str, Any]]: """从 AkShare 获取数据""" data = {} try: if data_type == "financial" or data_type == "overview": financial = await stock_data_service.get_financial_indicators(stock_code) if financial: data["financial_indicators"] = financial if data_type == "fund_flow" or data_type == "overview": fund_flow = await stock_data_service.get_fund_flow(stock_code, days=10) if fund_flow: data["fund_flow"] = fund_flow if data_type == "realtime" or data_type == "overview": realtime = await stock_data_service.get_realtime_quote(stock_code) if realtime: data["realtime_quote"] = realtime if data_type == "kline": kline = await stock_data_service.get_kline_data(stock_code, period="daily", limit=30) if kline: data["kline_summary"] = { "period": "daily", "count": len(kline), "latest": kline[-1] if kline else None, "recent_5": kline[-5:] if len(kline) >= 5 else kline } if data: logger.info(f"✅ AkShare 返回数据: {list(data.keys())}") return data except Exception as e: logger.warning(f"AkShare 搜索出错: {e}") return None async def _search_bochaai( self, query: str, stock_name: Optional[str] = None ) -> Optional[Dict[str, Any]]: """从 BochaAI 搜索新闻""" if not bochaai_search.is_available(): logger.debug("BochaAI 未配置,跳过") return None try: results = bochaai_search.search( query=query, freshness="oneWeek", count=10 ) if results: news_list = [ { "title": r.title, "snippet": r.snippet[:200] if r.snippet else "", "url": r.url, "source": r.site_name or "unknown", "date": r.date_published or "" } for r in results ] logger.info(f"✅ BochaAI 返回 {len(news_list)} 条新闻") return {"news": news_list, "count": len(news_list)} except Exception as e: logger.warning(f"BochaAI 搜索出错: {e}") return None async def _search_browser(self, query: str) -> Optional[Dict[str, Any]]: """使用交互式爬虫搜索""" try: loop = asyncio.get_event_loop() results = await loop.run_in_executor( None, lambda: self._interactive_crawler.interactive_search( query=query, engines=["baidu_news", "sogou"], num_results=10, search_type="news" ) ) if results: news_list = [ { "title": r.get("title", ""), "snippet": r.get("snippet", "")[:200], "url": r.get("url", ""), "source": "browser_search" } for r in results ] logger.info(f"✅ Browser 返回 {len(news_list)} 条结果") return {"search_results": news_list, "count": len(news_list)} except Exception as e: logger.warning(f"Browser 搜索出错: {e}") return None async def _search_knowledge_base( self, query: str, stock_code: str ) -> Optional[Dict[str, Any]]: """从知识库搜索历史数据""" try: from ..services.news_service import news_service if stock_code and news_service: news_list = await news_service.get_news_by_stock(stock_code, limit=10) if news_list: kb_news = [ { "title": getattr(news, 'title', ''), "content": 
(getattr(news, 'content', '') or '')[:300], "source": getattr(news, 'source', ''), "date": news.publish_time.isoformat() if hasattr(news, 'publish_time') and news.publish_time else "" } for news in news_list ] logger.info(f"✅ KB 返回 {len(kb_news)} 条历史新闻") return {"historical_news": kb_news, "count": len(kb_news)} except Exception as e: logger.debug(f"KB 搜索出错: {e}") return None async def _generate_combined_summary( self, query: str, data: Dict[str, Any], stock_name: str ) -> str: """生成综合摘要""" summary_parts = [f"## 搜索结果: {query}\n"] summary_parts.append(f"**股票**: {stock_name}\n") # AkShare 数据 if "akshare" in data: ak_data = data["akshare"] summary_parts.append("### 📊 财务/行情数据\n") if "financial_indicators" in ak_data: fi = ak_data["financial_indicators"] summary_parts.append(f"- PE: {fi.get('pe_ratio', 'N/A')}, PB: {fi.get('pb_ratio', 'N/A')}") summary_parts.append(f"- ROE: {fi.get('roe', 'N/A')}%") if "realtime_quote" in ak_data: rt = ak_data["realtime_quote"] summary_parts.append(f"- 当前价: {rt.get('price', 'N/A')}元, 涨跌幅: {rt.get('change_percent', 'N/A')}%") if "fund_flow" in ak_data: ff = ak_data["fund_flow"] summary_parts.append(f"- 资金流向: {ff.get('main_flow_trend', 'N/A')}") summary_parts.append("") # BochaAI 新闻 if "bochaai" in data: news = data["bochaai"].get("news", []) if news: summary_parts.append("### 📰 最新新闻\n") for i, n in enumerate(news[:5], 1): summary_parts.append(f"{i}. **{n['title'][:50]}**") if n.get('snippet'): summary_parts.append(f" {n['snippet'][:100]}...") summary_parts.append("") # Browser 结果 if "browser" in data: results = data["browser"].get("search_results", []) if results: summary_parts.append("### 🌐 网页搜索结果\n") for i, r in enumerate(results[:5], 1): summary_parts.append(f"{i}. {r['title'][:50]}") summary_parts.append("") # KB 历史数据 if "kb" in data: kb_news = data["kb"].get("historical_news", []) if kb_news: summary_parts.append("### 📚 历史资料\n") for i, n in enumerate(kb_news[:3], 1): summary_parts.append(f"{i}. {n['title'][:50]}") summary_parts.append("") return "\n".join(summary_parts) # ============ 兼容旧 API ============ async def collect_data( self, stock_code: str, stock_name: str, data_requirements: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """ 搜集股票相关数据(兼容旧 API) """ # 创建并执行一个全面的搜索计划 plan = await self.generate_search_plan( query="综合数据搜集", stock_code=stock_code, stock_name=stock_name ) # 添加所有基础数据任务 plan.tasks = [ SearchTask( id=f"task_init_1", source=SearchSource.AKSHARE, query=stock_code, description="获取财务和行情数据", data_type="overview", icon="📊", estimated_time=3 ), SearchTask( id=f"task_init_2", source=SearchSource.KNOWLEDGE_BASE, query=stock_code, description="获取历史新闻", data_type="news", icon="📚", estimated_time=2 ) ] return await self.execute_search_plan(plan) # 快速分析师(保持不变) class QuickAnalystAgent(Agent): """快速分析师智能体""" def __init__(self, llm_provider=None, organization_id: str = "finnews"): super().__init__( name="QuickAnalyst", role="快速分析师", goal="快速综合多角度给出投资建议", backstory="""你是一位经验丰富的量化分析师,擅长快速分析和决策。 你能够在短时间内综合考虑多空因素,给出简洁明了的投资建议。 你的分析风格是:快速、准确、实用。""", organization_id=organization_id ) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) logger.info(f"Initialized {self.name} agent") async def quick_analyze( self, stock_code: str, stock_name: str, context: str ) -> Dict[str, Any]: """快速分析""" current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") prompt = f"""请对 {stock_name}({stock_code}) 进行快速投资分析。 【当前时间】 {current_time} 背景资料: {context} 请在1分钟内给出: 1. 核心观点(一句话) 2. 看多因素(3点) 3. 看空因素(3点) 4. 
投资建议(买入/持有/卖出) 5. 目标价位和止损价位 请用简洁的语言,直接给出结论。""" try: response = self._llm_provider.invoke([ {"role": "system", "content": "你是快速分析师,擅长快速给出投资建议。"}, {"role": "user", "content": prompt} ]) content = response.content if hasattr(response, 'content') else str(response) return { "success": True, "analysis": content, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Quick analysis failed: {e}") return { "success": False, "error": str(e) } # 工厂函数 def create_data_collector(llm_provider=None) -> DataCollectorAgentV2: """创建数据专员实例""" return DataCollectorAgentV2(llm_provider=llm_provider) ================================================ FILE: backend/app/agents/debate_agents.py ================================================ """ 辩论智能体 - Phase 2 实现 Bull vs Bear 多智能体辩论机制 支持动态搜索:智能体可以在发言中请求额外数据 格式: [SEARCH: "查询内容" source:数据源] """ import logging from typing import List, Dict, Any, Optional from datetime import datetime from agenticx import Agent from ..services.llm_service import get_llm_provider logger = logging.getLogger(__name__) # 数据请求提示词片段(用于启用动态搜索的场景) DATA_REQUEST_HINT = """ 【数据请求】如果需要更多数据支撑你的论点,可以在发言末尾添加搜索请求: - [SEARCH: "具体数据需求" source:akshare] -- 财务/行情数据 - [SEARCH: "新闻关键词" source:bochaai] -- 最新新闻 - [SEARCH: "搜索内容"] -- 自动选择最佳数据源 请只在确实需要时使用,每次最多1-2个请求。""" class BullResearcherAgent(Agent): """ 看多研究员智能体 职责:基于新闻和数据,生成看多观点和投资建议 支持在辩论中请求额外数据 """ def __init__(self, llm_provider=None, organization_id: str = "finnews"): # 先调用父类初始化(Pydantic BaseModel) super().__init__( name="BullResearcher", role="看多研究员", goal="从积极角度分析股票,发现投资机会和增长潜力", backstory="""你是一位乐观但理性的股票研究员,擅长发现被低估的投资机会。 你善于从新闻和数据中提取正面信息,分析公司的增长潜力、竞争优势和市场机遇。 你的分析注重长期价值,但也关注短期催化剂。 当你发现数据不足以支撑论点时,你会主动请求补充数据。""", organization_id=organization_id ) # 在 super().__init__() 之后设置 _llm_provider(避免被 Pydantic 清除) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) logger.info(f"Initialized {self.name} agent") def analyze( self, stock_code: str, stock_name: str, news_list: List[Dict[str, Any]], context: str = "" ) -> Dict[str, Any]: """ 生成看多分析报告 """ news_summary = self._summarize_news(news_list) # 获取当前系统时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") prompt = f"""你是一位看多研究员,请从积极角度分析以下股票: 【当前时间】 {current_time} 【股票信息】 代码:{stock_code} 名称:{stock_name} 【相关新闻摘要】 {news_summary} 【分析背景】 {context if context else "无额外背景信息"} 请从以下角度进行看多分析: ## 1. 核心看多逻辑 - 列出3-5个看多的核心理由 - 每个理由需要有数据或新闻支撑 ## 2. 增长催化剂 - 短期催化剂(1-3个月内可能发生的利好) - 中长期催化剂(3-12个月的增长驱动力) ## 3. 估值分析 - 当前估值是否具有吸引力 - 与同行业对比的优势 ## 4. 目标预期 - 给出合理的预期收益空间 - 说明达成条件 ## 5. 风险提示 - 虽然看多,但也需要指出可能的风险 请确保分析客观、有理有据,避免盲目乐观。 """ try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": prompt} ]) analysis_text = response.content if hasattr(response, 'content') else str(response) return { "success": True, "agent_name": self.name, "agent_role": self.role, "stance": "bull", "analysis": analysis_text, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Bull analysis failed: {e}") return { "success": False, "agent_name": self.name, "stance": "bull", "error": str(e) } async def debate_round(self, prompt: str, enable_data_request: bool = True) -> str: """ 辩论回合发言(用于实时辩论模式) Args: prompt: 辩论提示词 enable_data_request: 是否启用数据请求功能 Returns: 发言内容(可能包含数据请求标记) """ system_content = f"""你是{self.role},{self.backstory} 你正在参与一场多空辩论,请用专业但有说服力的语气发言。 作为看多方,你的核心任务是: 1. 挖掘公司的增长潜力和投资价值 2. 用数据和事实支撑你的乐观观点 3. 反驳看空方提出的风险点 4. 
识别被市场低估的机会""" if enable_data_request: system_content += DATA_REQUEST_HINT try: response = self._llm_provider.invoke([ {"role": "system", "content": system_content}, {"role": "user", "content": prompt} ]) return response.content if hasattr(response, 'content') else str(response) except Exception as e: logger.error(f"Bull debate round failed: {e}") return f"[发言出错: {e}]" def _summarize_news(self, news_list: List[Dict[str, Any]]) -> str: """汇总新闻信息""" if not news_list: return "暂无相关新闻" summaries = [] for i, news in enumerate(news_list[:5], 1): title = news.get("title", "") sentiment = news.get("sentiment_score") sentiment_text = "" if sentiment is not None: if sentiment > 0.1: sentiment_text = "(利好)" elif sentiment < -0.1: sentiment_text = "(利空)" else: sentiment_text = "(中性)" summaries.append(f"{i}. {title} {sentiment_text}") return "\n".join(summaries) class BearResearcherAgent(Agent): """ 看空研究员智能体 职责:基于新闻和数据,识别风险和潜在问题 支持在辩论中请求额外数据 """ def __init__(self, llm_provider=None, organization_id: str = "finnews"): # 先调用父类初始化(Pydantic BaseModel) super().__init__( name="BearResearcher", role="看空研究员", goal="从风险角度分析股票,识别潜在问题和下行风险", backstory="""你是一位谨慎的股票研究员,擅长发现被忽视的风险。 你善于从新闻和数据中提取负面信号,分析公司的潜在问题、竞争威胁和市场风险。 你的分析注重风险控制,帮助投资者避免损失。 当你发现数据不足以支撑风险判断时,你会主动请求补充数据。""", organization_id=organization_id ) # 在 super().__init__() 之后设置 _llm_provider(避免被 Pydantic 清除) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) logger.info(f"Initialized {self.name} agent") def analyze( self, stock_code: str, stock_name: str, news_list: List[Dict[str, Any]], context: str = "" ) -> Dict[str, Any]: """ 生成看空分析报告 """ news_summary = self._summarize_news(news_list) # 获取当前系统时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") prompt = f"""你是一位看空研究员,请从风险角度分析以下股票: 【当前时间】 {current_time} 【股票信息】 代码:{stock_code} 名称:{stock_name} 【相关新闻摘要】 {news_summary} 【分析背景】 {context if context else "无额外背景信息"} 请从以下角度进行风险分析: ## 1. 核心风险因素 - 列出3-5个主要风险点 - 每个风险需要有数据或新闻支撑 ## 2. 负面催化剂 - 短期可能出现的利空事件 - 中长期的结构性风险 ## 3. 估值风险 - 当前估值是否过高 - 与同行业对比的劣势 ## 4. 下行空间 - 分析可能的下跌幅度 - 触发下跌的条件 ## 5. 反驳看多观点 - 针对常见的看多逻辑提出质疑 - 指出乐观预期的不确定性 请确保分析客观、有理有据,避免无根据的悲观。 """ try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": prompt} ]) analysis_text = response.content if hasattr(response, 'content') else str(response) return { "success": True, "agent_name": self.name, "agent_role": self.role, "stance": "bear", "analysis": analysis_text, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Bear analysis failed: {e}") return { "success": False, "agent_name": self.name, "stance": "bear", "error": str(e) } def _summarize_news(self, news_list: List[Dict[str, Any]]) -> str: """汇总新闻信息""" if not news_list: return "暂无相关新闻" summaries = [] for i, news in enumerate(news_list[:5], 1): title = news.get("title", "") sentiment = news.get("sentiment_score") sentiment_text = "" if sentiment is not None: if sentiment > 0.1: sentiment_text = "(利好)" elif sentiment < -0.1: sentiment_text = "(利空)" else: sentiment_text = "(中性)" summaries.append(f"{i}. {title} {sentiment_text}") return "\n".join(summaries) async def debate_round(self, prompt: str, enable_data_request: bool = True) -> str: """ 辩论回合发言(用于实时辩论模式) Args: prompt: 辩论提示词 enable_data_request: 是否启用数据请求功能 Returns: 发言内容(可能包含数据请求标记) """ system_content = f"""你是{self.role},{self.backstory} 你正在参与一场多空辩论,请用专业但有说服力的语气发言。 作为看空方,你的核心任务是: 1. 识别公司的潜在风险和问题 2. 用数据和事实支撑你的谨慎观点 3. 反驳看多方过于乐观的论点 4. 
揭示被市场忽视的风险因素""" if enable_data_request: system_content += DATA_REQUEST_HINT try: response = self._llm_provider.invoke([ {"role": "system", "content": system_content}, {"role": "user", "content": prompt} ]) return response.content if hasattr(response, 'content') else str(response) except Exception as e: logger.error(f"Bear debate round failed: {e}") return f"[发言出错: {e}]" class InvestmentManagerAgent(Agent): """ 投资经理智能体 职责:综合 Bull/Bear 观点,做出最终投资决策 支持在决策前请求额外数据 """ def __init__(self, llm_provider=None, organization_id: str = "finnews"): # 先调用父类初始化(Pydantic BaseModel) super().__init__( name="InvestmentManager", role="投资经理", goal="综合多方观点,做出理性的投资决策", backstory="""你是一位经验丰富的投资经理,擅长在多方观点中找到平衡。 你善于综合看多和看空的分析,结合市场环境,做出最优的投资决策。 你的决策注重风险收益比,追求稳健的长期回报。 当你认为辩论双方提供的数据不足以做出决策时,你会主动请求补充关键数据。""", organization_id=organization_id ) # 在 super().__init__() 之后设置 _llm_provider(避免被 Pydantic 清除) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) logger.info(f"Initialized {self.name} agent") def make_decision( self, stock_code: str, stock_name: str, bull_analysis: str, bear_analysis: str, context: str = "", enable_data_request: bool = False ) -> Dict[str, Any]: """ 综合双方观点,做出投资决策 Args: stock_code: 股票代码 stock_name: 股票名称 bull_analysis: 看多分析 bear_analysis: 看空分析 context: 市场背景和补充数据 enable_data_request: 是否允许请求额外数据 """ # 获取当前系统时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") prompt = f"""你是一位投资经理,请综合以下看多和看空观点,做出投资决策: 【当前时间】 {current_time} 【股票信息】 代码:{stock_code} 名称:{stock_name} 【看多观点】 {bull_analysis} 【看空观点】 {bear_analysis} 【市场背景及补充数据】 {context if context else "当前市场处于正常波动区间"} 请按以下结构给出最终决策: ## 1. 观点评估 ### 看多方论点质量 - 评估看多论点的说服力(1-10分) - 指出最有力的看多论据 - 指出看多方忽视的问题 ### 看空方论点质量 - 评估看空论点的说服力(1-10分) - 指出最有力的看空论据 - 指出看空方过于悲观的地方 ## 2. 数据充分性评估 - 辩论中使用的数据是否充分? - 是否有关键数据缺失影响决策? - 已获得的补充数据如何影响判断? ## 3. 综合判断 - 当前股票的核心矛盾是什么 - 短期(1-3个月)和中长期(6-12个月)的观点 ## 4. 投资决策 **最终评级**:[强烈推荐 / 推荐 / 中性 / 谨慎 / 回避] **决策理由**: (详细说明决策依据) **建议操作**: - 对于持仓者:持有/加仓/减仓/清仓 - 对于观望者:买入/观望/规避 **关键监测指标**: - 列出需要持续关注的信号 - 什么情况下需要调整决策 ## 5. 
风险收益比 - 预期收益空间 - 潜在下行风险 - 风险收益比评估 请确保决策客观、理性,充分考虑双方观点和已获取的数据。 """ if enable_data_request: prompt += f""" 【数据请求】如果你认为还需要更多数据才能做出准确决策,可以添加搜索请求: - [SEARCH: "具体数据需求" source:akshare] - [SEARCH: "新闻关键词" source:bochaai] 但请优先基于现有数据做出判断。""" try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": prompt} ]) decision_text = response.content if hasattr(response, 'content') else str(response) # 提取评级 rating = self._extract_rating(decision_text) return { "success": True, "agent_name": self.name, "agent_role": self.role, "decision": decision_text, "rating": rating, "timestamp": datetime.utcnow().isoformat() } except Exception as e: logger.error(f"Investment decision failed: {e}") return { "success": False, "agent_name": self.name, "error": str(e) } def _extract_rating(self, text: str) -> str: """从决策文本中提取评级""" import re ratings = ["强烈推荐", "推荐", "中性", "谨慎", "回避"] for rating in ratings: if rating in text: return rating return "中性" class DebateWorkflow: """ 辩论工作流 协调 Bull/Bear/InvestmentManager 进行多轮辩论 """ def __init__(self, llm_provider=None): self.bull_agent = BullResearcherAgent(llm_provider) self.bear_agent = BearResearcherAgent(llm_provider) self.manager_agent = InvestmentManagerAgent(llm_provider) # 执行轨迹记录 self.trajectory = [] logger.info("Initialized DebateWorkflow") async def run_debate( self, stock_code: str, stock_name: str, news_list: List[Dict[str, Any]], context: str = "", rounds: int = 1 ) -> Dict[str, Any]: """ 执行完整的辩论流程 Args: stock_code: 股票代码 stock_name: 股票名称 news_list: 相关新闻列表 context: 额外上下文 rounds: 辩论轮数 Returns: 辩论结果 """ start_time = datetime.utcnow() self.trajectory = [] logger.info(f"🚀 辩论工作流开始: {stock_name}({stock_code}), 新闻数量={len(news_list)}") try: # 第一阶段:独立分析 self._log_step("debate_start", { "stock_code": stock_code, "stock_name": stock_name, "news_count": len(news_list) }) # Bull 分析 logger.info("📈 开始看多分析 (BullResearcher)...") self._log_step("bull_analysis_start", {"agent": "BullResearcher"}) bull_result = self.bull_agent.analyze(stock_code, stock_name, news_list, context) logger.info(f"📈 看多分析完成: success={bull_result.get('success', False)}") self._log_step("bull_analysis_complete", { "agent": "BullResearcher", "success": bull_result.get("success", False) }) # Bear 分析 logger.info("📉 开始看空分析 (BearResearcher)...") self._log_step("bear_analysis_start", {"agent": "BearResearcher"}) bear_result = self.bear_agent.analyze(stock_code, stock_name, news_list, context) logger.info(f"📉 看空分析完成: success={bear_result.get('success', False)}") self._log_step("bear_analysis_complete", { "agent": "BearResearcher", "success": bear_result.get("success", False) }) # 第二阶段:投资经理决策 logger.info("⚖️ 开始投资经理决策 (InvestmentManager)...") self._log_step("decision_start", {"agent": "InvestmentManager"}) decision_result = self.manager_agent.make_decision( stock_code=stock_code, stock_name=stock_name, bull_analysis=bull_result.get("analysis", ""), bear_analysis=bear_result.get("analysis", ""), context=context ) logger.info(f"⚖️ 投资经理决策完成: rating={decision_result.get('rating', 'unknown')}") self._log_step("decision_complete", { "agent": "InvestmentManager", "rating": decision_result.get("rating", "unknown") }) end_time = datetime.utcnow() execution_time = (end_time - start_time).total_seconds() logger.info(f"✅ 辩论工作流完成! 
耗时={execution_time:.2f}秒, 评级={decision_result.get('rating', 'unknown')}") self._log_step("debate_complete", { "execution_time": execution_time, "final_rating": decision_result.get("rating", "unknown") }) return { "success": True, "stock_code": stock_code, "stock_name": stock_name, "bull_analysis": bull_result, "bear_analysis": bear_result, "final_decision": decision_result, "trajectory": self.trajectory, "execution_time": execution_time, "timestamp": start_time.isoformat() } except Exception as e: logger.error(f"❌ 辩论工作流失败: {e}", exc_info=True) self._log_step("debate_failed", {"error": str(e)}) return { "success": False, "error": str(e), "trajectory": self.trajectory } def _log_step(self, step_name: str, data: Dict[str, Any]): """记录执行步骤""" step = { "step": step_name, "timestamp": datetime.utcnow().isoformat(), "data": data } self.trajectory.append(step) logger.info(f"Debate step: {step_name} - {data}") # 工厂函数 def create_debate_workflow(llm_provider=None) -> DebateWorkflow: """创建辩论工作流实例""" return DebateWorkflow(llm_provider) ================================================ FILE: backend/app/agents/news_analyst.py ================================================ """ 新闻分析师智能体 """ import logging from typing import List, Dict, Any, Optional from agenticx import Agent, Task, BaseTool from agenticx.core.agent_executor import AgentExecutor from ..services.llm_service import get_llm_provider from ..tools import TextCleanerTool logger = logging.getLogger(__name__) class NewsAnalystAgent(Agent): """ 新闻分析师智能体 职责:分析金融新闻的情感、影响和关键信息 """ def __init__( self, llm_provider=None, tools: Optional[List[BaseTool]] = None, organization_id: str = "finnews", **kwargs ): """ 初始化新闻分析师智能体 Args: llm_provider: LLM 提供者 tools: 工具列表 organization_id: 组织ID(用于多租户隔离),默认 "finnews" **kwargs: 额外参数 """ # 如果没有提供 LLM,使用默认的 if llm_provider is None: llm_provider = get_llm_provider() # 如果没有提供工具,使用默认工具 if tools is None: tools = [TextCleanerTool()] # 保存 LLM 和工具供后续使用(在 super().__init__ 之前保存) self._llm_provider = llm_provider self._tools = tools # 定义智能体属性(Agent 基类) super().__init__( name="NewsAnalyst", role="金融新闻分析师", goal="深度分析金融新闻,提取关键信息,评估市场影响", backstory="""你是一位经验丰富的金融新闻分析专家,具有10年以上的证券市场分析经验。 你擅长从新闻中提取关键信息,准确判断新闻对股票市场的影响,并能够识别潜在的投资机会和风险。 你的分析报告准确、专业,深受投资者信赖。""", organization_id=organization_id, **kwargs ) # 创建 AgentExecutor(在 super().__init__ 之后) self._executor = None self._init_executor(llm_provider, tools) logger.info(f"Initialized {self.name} agent") def _init_executor(self, llm_provider=None, tools=None): """初始化 AgentExecutor(延迟初始化)""" if self._executor is None: if llm_provider is None: llm_provider = getattr(self, '_llm_provider', None) or get_llm_provider() if tools is None: tools = getattr(self, '_tools', None) or [TextCleanerTool()] self._llm_provider = llm_provider self._tools = tools self._executor = AgentExecutor( llm_provider=llm_provider, tools=tools ) @property def executor(self): """获取 AgentExecutor(延迟初始化)""" if self._executor is None: self._init_executor() return self._executor def analyze_news( self, news_title: str, news_content: str, news_url: str = "", stock_codes: List[str] = None ) -> Dict[str, Any]: """ 分析单条新闻 Args: news_title: 新闻标题 news_content: 新闻内容 news_url: 新闻URL stock_codes: 关联股票代码 Returns: 分析结果字典 """ # 构建分析提示词 prompt = f"""你是一位经验丰富的金融新闻分析专家,具有10年以上的证券市场分析经验。 你擅长从新闻中提取关键信息,准确判断新闻对股票市场的影响,并能够识别潜在的投资机会和风险。 请深度分析以下金融新闻,并提供结构化的分析报告: 【新闻标题】 {news_title} 【新闻内容】 {news_content[:2000]} 【关联股票】 {', '.join(stock_codes) if stock_codes else '无'} 请按照以下结构进行专业分析,并严格使用 Markdown 格式输出: ## 摘要 结构性分析,长期利好市场生态** ### 正面影响: - 核心要点1 - 
核心要点2 - 核心要点3 ### 潜在挑战: - 挑战点1 - 挑战点2 --- ## 1. 情感倾向:[中性偏利好] (评分:X.X) **情感判断**:[中性偏利好/利好/利空/中性]** **综合评分**:+X.X (范围:-1 至 +1)** **理由说明:** 详细说明评分依据,包括: - 政策影响分析 - 市场短期/长期影响 - 预期收益/风险评估 --- ## 2. 关键信息提取 **请使用标准 Markdown 表格格式,确保表格清晰易读:** | 类别 | 内容 | |------|------| | 公司名称 | XXX公司(全称,股票代码:XXXXXX) | | 事件时间 | 新闻发布时间:YYYY年MM月DD日;关键事件时间线涵盖YYYY年QXXX | | 股价变动 | 详细描述股价变化趋势和数据 | | 财务表现(YYYY年QX) | 关键财务指标(使用具体数字和增长率) | | 驱动因素 | • 因素1
<br>• 因素2<br>• 因素3 | | 分析师观点 | • 机构1(分析师):观点内容<br>• 机构2(分析师):观点内容 | | 市场情绪指标 | 具体指标和数据 | **重要说明(表格严格规范)**: - **禁止跨行**:同一类别下的所有内容必须在**同一行**的单元格内 - **强制换行**:如果同一单元格有多条内容,**必须**使用 `<br>` 分隔,**严禁**使用 Markdown 列表(- 或 1.)或直接换行 - **错误示例**(绝对禁止): | 驱动因素 | • 因素1 | | | • 因素2 | <-- 错误!不能另起一行 - **正确示例**: | 驱动因素 | • 因素1<br>• 因素2 | - 表头和内容之间用 `|------|------|` 分隔 - 数据要准确,有具体数字时必须标注 --- ## 3. 市场影响分析 ### 短期影响(1-3个月) - 影响点1:具体分析 - 影响点2:具体分析 ### 中期影响(3-12个月) - 影响点1:具体分析 - 影响点2:具体分析 ### 长期影响(1年以上) - 影响点1:具体分析 - 影响点2:具体分析 --- ## 4. 投资建议 **投资评级**:[推荐买入/谨慎持有/观望/减持] **建议理由**: 1. 核心逻辑1 2. 核心逻辑2 3. 核心逻辑3 **风险提示**: - 风险1 - 风险2 --- **格式要求(重要)**: 1. 必须使用标准 Markdown 语法 2. **表格内容严禁跨行**,单元格内换行只能用 `<br>
` 3. 标题层级清晰:使用 ##、### 等 4. 列表使用 - 或数字编号(表格外) 5. 加粗使用 **文本** 6. 分隔线使用 --- 7. 评分必须精确到小数点后1位 8. 所有数据必须真实、准确,来源于新闻内容 请确保分析报告专业、准确、结构清晰,特别注意表格格式的规范性,避免表格行错位。 """ try: # 确保 LLM provider 已初始化 if not hasattr(self, '_llm_provider') or self._llm_provider is None: self._llm_provider = get_llm_provider() logger.info(f"Calling LLM provider: {type(self._llm_provider).__name__}, model: {getattr(self._llm_provider, 'model', 'unknown')}") # 直接调用 LLM(不使用 AgentExecutor,避免审批暂停) response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": prompt} ]) logger.info("LLM response received") # 获取分析结果 analysis_text = response.content if hasattr(response, 'content') else str(response) # 修复 Markdown 表格格式 analysis_text = self._repair_markdown_table(analysis_text) # 尝试提取结构化信息 structured_result = self._extract_structured_info(analysis_text) return { "success": True, "analysis_result": analysis_text, "structured_data": structured_result, "agent_name": self.name, "agent_role": self.role, } except Exception as e: logger.error(f"News analysis failed: {e}", exc_info=True) return { "success": False, "error": str(e), "agent_name": self.name, } def _repair_markdown_table(self, text: str) -> str: """ 修复 Markdown 表格格式问题 主要解决:多行内容被错误拆分为多行单元格,导致首列为空的问题 """ import re lines = text.split('\n') new_lines = [] in_table = False last_table_line_idx = -1 for line in lines: stripped = line.strip() # 检测表格行 is_table_row = stripped.startswith('|') and stripped.endswith('|') is_separator = '---' in stripped and '|' in stripped if is_table_row: if not in_table: in_table = True # 如果是分隔行,直接添加 if is_separator: new_lines.append(line) last_table_line_idx = len(new_lines) - 1 continue # 检查是否是"坏行"(首列为空) # 匹配模式:| 空白 | 内容 | parts = [p.strip() for p in stripped.strip('|').split('|')] # 如果首列为空,且不是第一行,且上一行也是表格行 if len(parts) >= 2 and not parts[0] and last_table_line_idx >= 0: # 获取上一行 prev_line = new_lines[last_table_line_idx] prev_parts = [p.strip() for p in prev_line.strip().strip('|').split('|')] # 确保列数匹配 if len(parts) == len(prev_parts): # 将内容合并到上一行的对应列 for i in range(1, len(parts)): if parts[i]: prev_parts[i] = f"{prev_parts[i]}
• {parts[i]}" if parts[i].startswith('•') else f"{prev_parts[i]}
{parts[i]}" # 重建上一行 new_prev_line = '| ' + ' | '.join(prev_parts) + ' |' new_lines[last_table_line_idx] = new_prev_line # 当前行被合并,不添加到 new_lines continue else: in_table = False new_lines.append(line) if in_table: last_table_line_idx = len(new_lines) - 1 return '\n'.join(new_lines) def _extract_structured_info(self, analysis_text: str) -> Dict[str, Any]: """ 从分析文本中提取结构化信息 Args: analysis_text: 分析文本 Returns: 结构化数据 """ import re result = { "sentiment": "neutral", "sentiment_score": 0.0, "confidence": 0.5, "key_points": [], "market_impact": "", "investment_advice": "", } try: # 提取情感倾向(支持多种格式) # 匹配:利好、利空、中性、显著利好、显著利空等 sentiment_patterns = [ r'情感倾向[::]\s*\*?\*?(显著|明显)?(利好|利空|中性)', r'(显著|明显)?(利好|利空|中性)', # 备用模式 ] for pattern in sentiment_patterns: sentiment_match = re.search(pattern, analysis_text) if sentiment_match: # 提取最后一个匹配的词(利好/利空/中性) groups = [g for g in sentiment_match.groups() if g] if groups: sentiment_word = groups[-1] sentiment_map = {"利好": "positive", "利空": "negative", "中性": "neutral"} result["sentiment"] = sentiment_map.get(sentiment_word, "neutral") break # 提取情感评分(支持多种格式) # 匹配:-0.92、**-0.92**、-0.92 / -1.0 等格式 score_patterns = [ r'综合评分[::]\s*\*?\*?([-+]?\d*\.?\d+)', # 综合评分:-0.92(优先级最高) r'评分[::]\s*\*?\*?([-+]?\d*\.?\d+)\s*/\s*[-+]?\d*\.?\d+', # 评分:-0.85 / 1.0 r'情感评分[::]\s*\*?\*?([-+]?\d*\.?\d+)', # 情感评分:-0.92 r'评分[::]\s*\*?\*?([-+]?\d*\.?\d+)', # 评分:-0.92 ] for pattern in score_patterns: score_match = re.search(pattern, analysis_text) if score_match: result["sentiment_score"] = float(score_match.group(1)) logger.info(f"Extracted sentiment score: {result['sentiment_score']}") break # 如果未提取到评分,尝试从情感倾向推断 if result["sentiment_score"] == 0.0 and result["sentiment"] != "neutral": if result["sentiment"] == "positive": result["sentiment_score"] = 0.5 # 默认中等利好 elif result["sentiment"] == "negative": result["sentiment_score"] = -0.5 # 默认中等利空 # 提取置信度 confidence_match = re.search(r'置信度[::]\s*\*?\*?(\d*\.?\d+)', analysis_text) if confidence_match: result["confidence"] = float(confidence_match.group(1)) # 提取关键信息点(简单实现:查找列表) key_points_section = re.search(r'关键信息[::](.*?)(?=市场影响|投资建议|$)', analysis_text, re.DOTALL) if key_points_section: points_text = key_points_section.group(1) points = re.findall(r'[•\-\*]\s*(.+)', points_text) result["key_points"] = [p.strip() for p in points if p.strip()] # 提取市场影响 impact_match = re.search(r'市场影响[::](.*?)(?=投资建议|置信度|$)', analysis_text, re.DOTALL) if impact_match: result["market_impact"] = impact_match.group(1).strip() # 提取投资建议 advice_match = re.search(r'投资建议[::](.*?)(?=置信度|$)', analysis_text, re.DOTALL) if advice_match: result["investment_advice"] = advice_match.group(1).strip() except Exception as e: logger.warning(f"Failed to extract structured info: {e}") # 日志记录提取结果 logger.info( f"Extracted sentiment: {result['sentiment']}, " f"score: {result['sentiment_score']}, " f"confidence: {result['confidence']}" ) return result def batch_analyze( self, news_list: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """ 批量分析新闻 Args: news_list: 新闻列表 Returns: 分析结果列表 """ results = [] for news in news_list: try: result = self.analyze_news( news_title=news.get("title", ""), news_content=news.get("content", ""), news_url=news.get("url", ""), stock_codes=news.get("stock_codes", []) ) results.append(result) except Exception as e: logger.error(f"Failed to analyze news: {e}") results.append({ "success": False, "error": str(e), "news_url": news.get("url", "") }) return results def create_news_analyst( llm_provider=None, tools: Optional[List[BaseTool]] = None, organization_id: str = 
"finnews" ) -> NewsAnalystAgent: """ 创建新闻分析师智能体实例 Args: llm_provider: LLM 提供者 tools: 工具列表 organization_id: 组织ID(用于多租户隔离),默认 "finnews" Returns: NewsAnalystAgent 实例 """ return NewsAnalystAgent( llm_provider=llm_provider, tools=tools, organization_id=organization_id ) ================================================ FILE: backend/app/agents/orchestrator.py ================================================ """ 协作编排器 负责管理多智能体协作流程,支持: - 并行分析模式(parallel) - 实时辩论模式(realtime_debate) - 快速分析模式(quick_analysis) - 动态搜索模式(在辩论过程中按需获取数据) """ import logging import asyncio from typing import Dict, Any, List, Optional, Callable, AsyncGenerator from datetime import datetime from enum import Enum from ..config import get_mode_config, get_default_mode, DebateModeConfig from ..services.llm_service import get_llm_provider logger = logging.getLogger(__name__) class DebatePhase(Enum): """辩论阶段""" INITIALIZING = "initializing" DATA_COLLECTION = "data_collection" OPENING = "opening" DEBATE = "debate" CLOSING = "closing" COMPLETED = "completed" FAILED = "failed" class DebateEvent: """辩论事件(用于实时流式输出)""" def __init__( self, event_type: str, agent_name: str, content: str, phase: DebatePhase, round_number: Optional[int] = None, metadata: Optional[Dict[str, Any]] = None ): self.event_type = event_type self.agent_name = agent_name self.content = content self.phase = phase self.round_number = round_number self.metadata = metadata or {} self.timestamp = datetime.utcnow().isoformat() def to_dict(self) -> Dict[str, Any]: return { "event_type": self.event_type, "agent_name": self.agent_name, "content": self.content, "phase": self.phase.value, "round_number": self.round_number, "metadata": self.metadata, "timestamp": self.timestamp } class DebateOrchestrator: """辩论编排器""" def __init__( self, mode: str = None, llm_provider=None, enable_dynamic_search: bool = True ): """ 初始化辩论编排器 Args: mode: 辩论模式 (parallel, realtime_debate, quick_analysis) llm_provider: LLM 提供者 enable_dynamic_search: 是否启用动态搜索(辩论中按需获取数据) """ self.mode = mode or get_default_mode() self.config = get_mode_config(self.mode) if not self.config: raise ValueError(f"未知的辩论模式: {self.mode}") self.llm_provider = llm_provider or get_llm_provider() self.current_phase = DebatePhase.INITIALIZING self.current_round = 0 self.start_time: Optional[datetime] = None self.events: List[DebateEvent] = [] self.is_interrupted = False # 动态搜索配置 self.enable_dynamic_search = enable_dynamic_search self._search_analyst = None # 搜索统计 self.search_stats = { "total_requests": 0, "successful_searches": 0, "data_supplements": [] } # 事件回调 self._event_callbacks: List[Callable[[DebateEvent], None]] = [] logger.info(f"🎭 初始化辩论编排器,模式: {self.mode}, 动态搜索: {enable_dynamic_search}") def _get_search_analyst(self): """懒加载搜索分析师""" if self._search_analyst is None and self.enable_dynamic_search: from .search_analyst import SearchAnalystAgent self._search_analyst = SearchAnalystAgent(self.llm_provider) return self._search_analyst def on_event(self, callback: Callable[[DebateEvent], None]): """注册事件回调""" self._event_callbacks.append(callback) def _emit_event(self, event: DebateEvent): """触发事件""" self.events.append(event) for callback in self._event_callbacks: try: callback(event) except Exception as e: logger.error(f"事件回调出错: {e}") def interrupt(self, reason: str = "manager_decision"): """打断辩论""" self.is_interrupted = True self._emit_event(DebateEvent( event_type="interrupt", agent_name="InvestmentManager", content=f"辩论被打断: {reason}", phase=self.current_phase )) logger.info(f"⚡ 辩论被打断: {reason}") async def run( self, stock_code: 
str, stock_name: str, context: str = "", news_list: List[Dict[str, Any]] = None ) -> Dict[str, Any]: """运行辩论流程""" self.start_time = datetime.utcnow() result = { "success": False, "mode": self.mode, "stock_code": stock_code, "stock_name": stock_name, "trajectory": [], "events": [] } try: self._emit_event(DebateEvent( event_type="start", agent_name="Orchestrator", content=f"开始 {self.config.name}", phase=DebatePhase.INITIALIZING )) # 根据模式选择执行流程 if self.config.flow.type == "parallel_then_summarize": result = await self._run_parallel_mode(stock_code, stock_name, context, news_list) elif self.config.flow.type == "orchestrated_debate": result = await self._run_realtime_debate_mode(stock_code, stock_name, context, news_list) elif self.config.flow.type == "single_agent": result = await self._run_quick_mode(stock_code, stock_name, context) else: raise ValueError(f"未知的流程类型: {self.config.flow.type}") self.current_phase = DebatePhase.COMPLETED self._emit_event(DebateEvent( event_type="complete", agent_name="Orchestrator", content="辩论完成", phase=DebatePhase.COMPLETED )) except Exception as e: logger.error(f"辩论执行失败: {e}", exc_info=True) self.current_phase = DebatePhase.FAILED result["error"] = str(e) self._emit_event(DebateEvent( event_type="error", agent_name="Orchestrator", content=f"辩论失败: {e}", phase=DebatePhase.FAILED )) result["events"] = [e.to_dict() for e in self.events] result["execution_time"] = (datetime.utcnow() - self.start_time).total_seconds() return result async def _run_parallel_mode( self, stock_code: str, stock_name: str, context: str, news_list: List[Dict[str, Any]] ) -> Dict[str, Any]: """运行并行分析模式""" from .debate_agents import BullResearcherAgent, BearResearcherAgent, InvestmentManagerAgent logger.info("🔄 执行并行分析模式") # 初始化智能体 bull_agent = BullResearcherAgent(self.llm_provider) bear_agent = BearResearcherAgent(self.llm_provider) manager_agent = InvestmentManagerAgent(self.llm_provider) # 准备新闻摘要 news_summary = self._prepare_news_summary(news_list) full_context = f"{context}\n\n{news_summary}" if context else news_summary self.current_phase = DebatePhase.DEBATE # 并行执行Bull和Bear分析 self._emit_event(DebateEvent( event_type="analysis_start", agent_name="BullResearcher", content="开始看多分析", phase=self.current_phase )) self._emit_event(DebateEvent( event_type="analysis_start", agent_name="BearResearcher", content="开始看空分析", phase=self.current_phase )) bull_task = asyncio.create_task( bull_agent.analyze(stock_code, stock_name, full_context) ) bear_task = asyncio.create_task( bear_agent.analyze(stock_code, stock_name, full_context) ) bull_analysis, bear_analysis = await asyncio.gather(bull_task, bear_task) self._emit_event(DebateEvent( event_type="analysis_complete", agent_name="BullResearcher", content=bull_analysis.get("analysis", "")[:200] + "...", phase=self.current_phase )) self._emit_event(DebateEvent( event_type="analysis_complete", agent_name="BearResearcher", content=bear_analysis.get("analysis", "")[:200] + "...", phase=self.current_phase )) # 投资经理做决策 self.current_phase = DebatePhase.CLOSING self._emit_event(DebateEvent( event_type="decision_start", agent_name="InvestmentManager", content="开始综合决策", phase=self.current_phase )) final_decision = await manager_agent.make_decision( stock_code=stock_code, stock_name=stock_name, bull_analysis=bull_analysis.get("analysis", ""), bear_analysis=bear_analysis.get("analysis", ""), context=full_context ) self._emit_event(DebateEvent( event_type="decision_complete", agent_name="InvestmentManager", content=f"决策完成: {final_decision.get('rating', 'N/A')}", 
phase=self.current_phase )) return { "success": True, "mode": self.mode, "bull_analysis": bull_analysis, "bear_analysis": bear_analysis, "final_decision": final_decision, "trajectory": [ {"agent": "BullResearcher", "action": "analyze", "status": "completed"}, {"agent": "BearResearcher", "action": "analyze", "status": "completed"}, {"agent": "InvestmentManager", "action": "decide", "status": "completed"} ] } async def _run_realtime_debate_mode( self, stock_code: str, stock_name: str, context: str, news_list: List[Dict[str, Any]] ) -> Dict[str, Any]: """运行实时辩论模式(支持动态搜索)""" from .debate_agents import BullResearcherAgent, BearResearcherAgent, InvestmentManagerAgent from .data_collector import DataCollectorAgent logger.info("🎭 执行实时辩论模式") # 初始化智能体 data_collector = DataCollectorAgent(self.llm_provider) bull_agent = BullResearcherAgent(self.llm_provider) bear_agent = BearResearcherAgent(self.llm_provider) manager_agent = InvestmentManagerAgent(self.llm_provider) # 获取搜索分析师(如果启用) search_analyst = self._get_search_analyst() rules = self.config.rules max_rounds = rules.max_rounds or 5 max_time = rules.max_time or 600 trajectory = [] debate_history = [] dynamic_data_supplements = [] # 记录动态搜索补充的数据 # Phase 1: 数据搜集 if rules.require_data_collection: self.current_phase = DebatePhase.DATA_COLLECTION self._emit_event(DebateEvent( event_type="phase_start", agent_name="DataCollector", content="开始搜集数据", phase=self.current_phase )) collected_data = await data_collector.collect_data(stock_code, stock_name) data_summary = collected_data.get("summary", "") self._emit_event(DebateEvent( event_type="data_collected", agent_name="DataCollector", content=data_summary[:300] + "...", phase=self.current_phase )) trajectory.append({ "agent": "DataCollector", "action": "collect_data", "status": "completed" }) # 合并数据到上下文 context = f"{context}\n\n{data_summary}" if context else data_summary # Phase 2: 投资经理开场 self.current_phase = DebatePhase.OPENING opening_prompt = f"""你是投资经理,现在要主持一场关于 {stock_name}({stock_code}) 的多空辩论。 请做开场陈述,说明: 1. 今天辩论的股票背景 2. 辩论的规则(最多{max_rounds}轮,每人每轮1分钟) 3. 
请看多研究员先发言 背景资料: {context[:2000]}""" self._emit_event(DebateEvent( event_type="opening", agent_name="InvestmentManager", content="投资经理开场中...", phase=self.current_phase )) opening = await self.llm_provider.chat(opening_prompt) self._emit_event(DebateEvent( event_type="speech", agent_name="InvestmentManager", content=opening, phase=self.current_phase, round_number=0 )) trajectory.append({ "agent": "InvestmentManager", "action": "opening", "status": "completed", "content": opening }) debate_history.append({ "round": 0, "agent": "InvestmentManager", "type": "opening", "content": opening }) # Phase 3: 辩论回合 self.current_phase = DebatePhase.DEBATE bull_analysis_full = "" bear_analysis_full = "" for round_num in range(1, max_rounds + 1): if self.is_interrupted: logger.info(f"辩论在第{round_num}轮被打断") break # 检查时间限制 elapsed = (datetime.utcnow() - self.start_time).total_seconds() if elapsed > max_time: logger.info(f"辩论超时,已进行 {elapsed:.0f} 秒") break self.current_round = round_num # Bull发言 self._emit_event(DebateEvent( event_type="round_start", agent_name="BullResearcher", content=f"第{round_num}轮 - 看多研究员发言", phase=self.current_phase, round_number=round_num )) bull_prompt = self._build_debate_prompt( agent_role="看多研究员", stock_name=stock_name, stock_code=stock_code, round_num=round_num, max_rounds=max_rounds, context=context, debate_history=debate_history, enable_search_requests=self.enable_dynamic_search ) bull_response = await bull_agent.debate_round(bull_prompt) bull_analysis_full += f"\n\n### 第{round_num}轮\n{bull_response}" self._emit_event(DebateEvent( event_type="speech", agent_name="BullResearcher", content=bull_response, phase=self.current_phase, round_number=round_num )) debate_history.append({ "round": round_num, "agent": "BullResearcher", "type": "argument", "content": bull_response }) # 动态搜索:处理 Bull 发言中的数据请求 if search_analyst: context, supplement = await self._process_speech_for_search( search_analyst=search_analyst, speech_text=bull_response, agent_name="BullResearcher", stock_code=stock_code, stock_name=stock_name, context=context, round_num=round_num, trajectory=trajectory ) if supplement: dynamic_data_supplements.append(supplement) # Bear发言 self._emit_event(DebateEvent( event_type="round_continue", agent_name="BearResearcher", content=f"第{round_num}轮 - 看空研究员发言", phase=self.current_phase, round_number=round_num )) bear_prompt = self._build_debate_prompt( agent_role="看空研究员", stock_name=stock_name, stock_code=stock_code, round_num=round_num, max_rounds=max_rounds, context=context, debate_history=debate_history, enable_search_requests=self.enable_dynamic_search ) bear_response = await bear_agent.debate_round(bear_prompt) bear_analysis_full += f"\n\n### 第{round_num}轮\n{bear_response}" self._emit_event(DebateEvent( event_type="speech", agent_name="BearResearcher", content=bear_response, phase=self.current_phase, round_number=round_num )) debate_history.append({ "round": round_num, "agent": "BearResearcher", "type": "argument", "content": bear_response }) # 动态搜索:处理 Bear 发言中的数据请求 if search_analyst: context, supplement = await self._process_speech_for_search( search_analyst=search_analyst, speech_text=bear_response, agent_name="BearResearcher", stock_code=stock_code, stock_name=stock_name, context=context, round_num=round_num, trajectory=trajectory ) if supplement: dynamic_data_supplements.append(supplement) trajectory.append({ "agent": "Debate", "action": f"round_{round_num}", "status": "completed" }) # 投资经理可选择打断或请求更多数据 if rules.manager_can_interrupt and round_num < max_rounds: should_interrupt, 
manager_data_request = await self._check_manager_interrupt_or_search( manager_agent, debate_history, stock_name, stock_code, search_analyst, context ) # 如果经理请求了更多数据,更新上下文 if manager_data_request: context = f"{context}\n\n【投资经理补充数据】\n{manager_data_request}" dynamic_data_supplements.append({ "round": round_num, "agent": "InvestmentManager", "data": manager_data_request }) if should_interrupt: self.interrupt("投资经理认为已有足够信息做决策") break # Phase 4: 投资经理总结决策 self.current_phase = DebatePhase.CLOSING self._emit_event(DebateEvent( event_type="closing_start", agent_name="InvestmentManager", content="投资经理正在做最终决策...", phase=self.current_phase )) # 如果启用了动态搜索,在做决策前进行智能数据补充 if search_analyst and len(dynamic_data_supplements) < 2: self._emit_event(DebateEvent( event_type="smart_supplement", agent_name="SearchAnalyst", content="智能分析数据缺口,补充关键信息...", phase=self.current_phase )) smart_result = await search_analyst.smart_data_supplement( stock_code=stock_code, stock_name=stock_name, existing_context=context, debate_history=debate_history ) if smart_result.get("success") and smart_result.get("combined_summary"): context = f"{context}\n\n【智能补充数据】\n{smart_result['combined_summary']}" dynamic_data_supplements.append({ "round": "pre_decision", "agent": "SearchAnalyst", "data": smart_result["combined_summary"] }) final_decision = await manager_agent.make_decision( stock_code=stock_code, stock_name=stock_name, bull_analysis=bull_analysis_full, bear_analysis=bear_analysis_full, context=f"{context}\n\n辩论历史:\n{self._format_debate_history(debate_history)}" ) self._emit_event(DebateEvent( event_type="decision", agent_name="InvestmentManager", content=final_decision.get("summary", ""), phase=self.current_phase, metadata={"rating": final_decision.get("rating")} )) trajectory.append({ "agent": "InvestmentManager", "action": "final_decision", "status": "completed" }) return { "success": True, "mode": self.mode, "bull_analysis": {"analysis": bull_analysis_full, "success": True}, "bear_analysis": {"analysis": bear_analysis_full, "success": True}, "final_decision": final_decision, "debate_history": debate_history, "total_rounds": self.current_round, "was_interrupted": self.is_interrupted, "trajectory": trajectory, "dynamic_search_enabled": self.enable_dynamic_search, "data_supplements": dynamic_data_supplements, "search_stats": self.search_stats } async def _process_speech_for_search( self, search_analyst, speech_text: str, agent_name: str, stock_code: str, stock_name: str, context: str, round_num: int, trajectory: List[Dict] ) -> tuple: """ 处理发言中的搜索请求 Returns: (updated_context, supplement_data) """ try: result = await search_analyst.process_debate_speech( speech_text=speech_text, stock_code=stock_code, stock_name=stock_name, agent_name=agent_name ) self.search_stats["total_requests"] += result.get("requests_found", 0) if result.get("success") and result.get("combined_summary"): self.search_stats["successful_searches"] += len(result.get("search_results", [])) self._emit_event(DebateEvent( event_type="dynamic_search", agent_name="SearchAnalyst", content=f"为 {agent_name} 补充了 {result['requests_found']} 项数据", phase=self.current_phase, round_number=round_num, metadata={"requests": result["requests_found"]} )) trajectory.append({ "agent": "SearchAnalyst", "action": f"search_for_{agent_name}", "status": "completed", "requests": result["requests_found"] }) # 更新上下文 new_context = f"{context}\n\n【{agent_name} 请求的补充数据】\n{result['combined_summary']}" supplement = { "round": round_num, "agent": agent_name, "requests": result["requests_found"], 
"data": result["combined_summary"][:500] } return new_context, supplement except Exception as e: logger.warning(f"处理搜索请求时出错: {e}") return context, None async def _run_quick_mode( self, stock_code: str, stock_name: str, context: str ) -> Dict[str, Any]: """运行快速分析模式""" from .data_collector import QuickAnalystAgent logger.info("🚀 执行快速分析模式") quick_analyst = QuickAnalystAgent(self.llm_provider) self.current_phase = DebatePhase.DEBATE self._emit_event(DebateEvent( event_type="quick_analysis_start", agent_name="QuickAnalyst", content="开始快速分析", phase=self.current_phase )) result = await quick_analyst.quick_analyze(stock_code, stock_name, context) self._emit_event(DebateEvent( event_type="quick_analysis_complete", agent_name="QuickAnalyst", content=result.get("analysis", "")[:200] + "...", phase=self.current_phase )) return { "success": result.get("success", False), "mode": self.mode, "quick_analysis": result, "trajectory": [ {"agent": "QuickAnalyst", "action": "analyze", "status": "completed"} ] } def _prepare_news_summary(self, news_list: List[Dict[str, Any]]) -> str: """准备新闻摘要""" if not news_list: return "暂无相关新闻数据" summary_parts = ["## 相关新闻摘要\n"] for i, news in enumerate(news_list[:10], 1): title = news.get("title", "无标题") content = news.get("content", "")[:200] source = news.get("source", "未知来源") date = news.get("published_at", "") summary_parts.append(f"{i}. **{title}** ({source}, {date})\n {content}...\n") return "\n".join(summary_parts) def _build_debate_prompt( self, agent_role: str, stock_name: str, stock_code: str, round_num: int, max_rounds: int, context: str, debate_history: List[Dict], enable_search_requests: bool = False ) -> str: """构建辩论提示词""" history_text = self._format_debate_history(debate_history[-4:]) # 只取最近4条 # 基础提示词 prompt = f"""你是{agent_role},正在参与关于 {stock_name}({stock_code}) 的多空辩论。 当前是第 {round_num}/{max_rounds} 轮辩论。 背景资料: {context[:1500]} 最近的辩论历史: {history_text} 请发表你的观点(约200字): 1. 如果是第一轮,阐述你的核心论点 2. 如果不是第一轮,先反驳对方观点,再补充新论据 3. 用数据和事实支持你的论点 4. 语气专业但有说服力""" # 如果启用了动态搜索,添加搜索请求说明 if enable_search_requests: prompt += """ 【数据请求功能】 如果你在分析过程中发现缺少关键数据,可以在发言中使用以下格式请求搜索: - [SEARCH: "最新的毛利率数据" source:akshare] -- 从AkShare获取财务数据 - [SEARCH: "最近的行业新闻" source:bochaai] -- 从网络搜索新闻 - [SEARCH: "近期资金流向" source:akshare] -- 获取资金流向 - [SEARCH: "竞品对比分析"] -- 不指定来源则自动选择 搜索请求会在你发言后自动执行,数据会补充到下一轮的背景资料中。 请只在确实需要更多数据支撑论点时才使用搜索请求,每次最多1-2个。""" return prompt def _format_debate_history(self, history: List[Dict]) -> str: """格式化辩论历史""" if not history: return "(尚无辩论历史)" lines = [] for item in history: agent = item.get("agent", "Unknown") content = item.get("content", "")[:300] round_num = item.get("round", 0) lines.append(f"[第{round_num}轮 - {agent}]: {content}") return "\n\n".join(lines) async def _check_manager_interrupt( self, manager_agent, debate_history: List[Dict], stock_name: str ) -> bool: """检查投资经理是否要打断辩论""" if len(debate_history) < 4: return False check_prompt = f"""你是投资经理,正在主持关于 {stock_name} 的辩论。 目前的辩论历史: {self._format_debate_history(debate_history[-4:])} 请判断:你是否已经获得足够的信息来做出投资决策? 
如果是,回复"是";如果还需要更多辩论,回复"否"。 只回复一个字。""" try: response = await self.llm_provider.chat(check_prompt) return "是" in response[:5] except Exception: return False async def _check_manager_interrupt_or_search( self, manager_agent, debate_history: List[Dict], stock_name: str, stock_code: str, search_analyst, context: str ) -> tuple: """ 检查投资经理是否要打断辩论或请求更多数据 Returns: (should_interrupt: bool, additional_data: str or None) """ if len(debate_history) < 4: return False, None # 如果没有搜索分析师,使用简单的打断检查 if not search_analyst: should_interrupt = await self._check_manager_interrupt( manager_agent, debate_history, stock_name ) return should_interrupt, None check_prompt = f"""你是投资经理,正在主持关于 {stock_name}({stock_code}) 的多空辩论。 目前的辩论历史: {self._format_debate_history(debate_history[-4:])} 请判断当前情况: 1. 如果你已经获得足够的信息做决策,回复:决策就绪 2. 如果你需要更多数据支持,使用以下格式请求: [SEARCH: "你需要的具体数据" source:数据源] 可用数据源: akshare(财务/行情), bochaai(新闻), browser(网页搜索) 请只回复"决策就绪"或搜索请求,不要添加其他内容。""" try: response = await self.llm_provider.chat(check_prompt) # 检查是否决策就绪 if "决策就绪" in response: return True, None # 检查是否有搜索请求 requests = search_analyst.extract_search_requests(response) if requests: self._emit_event(DebateEvent( event_type="manager_search_request", agent_name="InvestmentManager", content=f"投资经理请求 {len(requests)} 项补充数据", phase=self.current_phase, round_number=self.current_round )) # 执行搜索 search_result = await search_analyst.process_debate_speech( speech_text=response, stock_code=stock_code, stock_name=stock_name, agent_name="InvestmentManager" ) if search_result.get("success") and search_result.get("combined_summary"): self.search_stats["total_requests"] += len(requests) self.search_stats["successful_searches"] += len(search_result.get("search_results", [])) return False, search_result["combined_summary"] return False, None except Exception as e: logger.warning(f"检查经理决策时出错: {e}") return False, None def create_orchestrator( mode: str = None, llm_provider=None, enable_dynamic_search: bool = True ) -> DebateOrchestrator: """ 创建辩论编排器 Args: mode: 辩论模式 (parallel, realtime_debate, quick_analysis) llm_provider: LLM 提供者 enable_dynamic_search: 是否启用动态搜索 Returns: DebateOrchestrator 实例 """ return DebateOrchestrator( mode=mode, llm_provider=llm_provider, enable_dynamic_search=enable_dynamic_search ) ================================================ FILE: backend/app/agents/quantitative_agent.py ================================================ """ 量化分析智能体 负责量化因子挖掘、技术分析和量化策略生成。 集成 Alpha Mining 模块,提供自动化因子发现能力。 功能: - 因子挖掘:使用 RL 自动发现有效交易因子 - 因子评估:评估因子的预测能力和回测表现 - 技术分析:结合传统技术指标进行分析 - 策略生成:基于因子生成交易策略建议 """ import logging import asyncio from typing import Dict, Any, List, Optional from datetime import datetime import json logger = logging.getLogger(__name__) class QuantitativeAgent: """ 量化分析智能体 集成 Alpha Mining 模块,提供因子挖掘和量化分析能力。 Args: llm_provider: LLM 提供者 enable_alpha_mining: 是否启用因子挖掘 model_path: 预训练模型路径 Example: agent = QuantitativeAgent(llm_provider) result = await agent.analyze(stock_code, stock_name, market_data) """ def __init__( self, llm_provider=None, enable_alpha_mining: bool = True, model_path: Optional[str] = None ): self.llm_provider = llm_provider self.enable_alpha_mining = enable_alpha_mining self.model_path = model_path # 延迟初始化 Alpha Mining 组件 self._alpha_mining_initialized = False self._generator = None self._trainer = None self._vm = None self._evaluator = None self._market_builder = None self._sentiment_builder = None # 存储发现的因子 self.discovered_factors: List[Dict[str, Any]] = [] logger.info(f"QuantitativeAgent initialized (alpha_mining={enable_alpha_mining})") 
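# --- Illustrative usage sketch (added for clarity; not part of the class) ---
# The Alpha Mining stack is initialized lazily, so a caller only needs the
# module-level factory plus an LLM provider. A minimal, hypothetical driver
# could look like the commented example below; `my_llm_provider`, the stock
# code "600000" and the name "浦发银行" are placeholder assumptions, and
# omitting `market_data` makes `_mine_factors` fall back to mock data.
#
#   agent = create_quantitative_agent(llm_provider=my_llm_provider)
#   result = await agent.analyze("600000", "浦发银行")
#   if result["success"]:
#       for factor in agent.get_best_factors(top_k=3):
#           print(factor["formula_str"], factor["sortino"], factor["ic"])
# -----------------------------------------------------------------------------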
def _init_alpha_mining(self): """延迟初始化 Alpha Mining 组件""" if self._alpha_mining_initialized: return try: from ..alpha_mining import ( AlphaMiningConfig, FactorVocab, FactorVM, AlphaGenerator, AlphaTrainer, FactorEvaluator, MarketFeatureBuilder, SentimentFeatureBuilder ) config = AlphaMiningConfig() vocab = FactorVocab() self._vm = FactorVM(vocab=vocab) self._evaluator = FactorEvaluator(config=config) self._market_builder = MarketFeatureBuilder(config=config) self._sentiment_builder = SentimentFeatureBuilder(config=config) # 初始化生成器 self._generator = AlphaGenerator(vocab=vocab, config=config) # 如果有预训练模型,加载它 if self.model_path: try: self._generator = AlphaGenerator.load(self.model_path, vocab=vocab) logger.info(f"Loaded pretrained model from {self.model_path}") except Exception as e: logger.warning(f"Failed to load model: {e}") self._alpha_mining_initialized = True logger.info("Alpha Mining components initialized") except ImportError as e: logger.warning(f"Alpha Mining not available: {e}") self.enable_alpha_mining = False async def analyze( self, stock_code: str, stock_name: str, market_data: Optional[Dict[str, Any]] = None, sentiment_data: Optional[Dict[str, Any]] = None, context: str = "" ) -> Dict[str, Any]: """ 执行量化分析 Args: stock_code: 股票代码 stock_name: 股票名称 market_data: 行情数据(可选) sentiment_data: 情感数据(可选) context: 额外上下文 Returns: 分析结果字典 """ result = { "success": True, "stock_code": stock_code, "stock_name": stock_name, "timestamp": datetime.utcnow().isoformat(), "analysis_type": "quantitative", "factors_discovered": [], "technical_analysis": {}, "strategy_suggestion": "", "confidence": 0.0 } try: # 1. 因子挖掘(如果启用) if self.enable_alpha_mining: factor_result = await self._mine_factors( stock_code, stock_name, market_data, sentiment_data ) result["factors_discovered"] = factor_result.get("factors", []) result["factor_mining_stats"] = factor_result.get("stats", {}) # 2. 技术分析(使用 LLM) if self.llm_provider and market_data: tech_analysis = await self._technical_analysis( stock_code, stock_name, market_data, context ) result["technical_analysis"] = tech_analysis # 3. 
生成策略建议 if self.llm_provider: strategy = await self._generate_strategy( stock_code, stock_name, result, context ) result["strategy_suggestion"] = strategy.get("suggestion", "") result["confidence"] = strategy.get("confidence", 0.0) except Exception as e: logger.error(f"Quantitative analysis failed: {e}", exc_info=True) result["success"] = False result["error"] = str(e) return result async def _mine_factors( self, stock_code: str, stock_name: str, market_data: Optional[Dict[str, Any]], sentiment_data: Optional[Dict[str, Any]] ) -> Dict[str, Any]: """执行因子挖掘""" self._init_alpha_mining() if not self._alpha_mining_initialized: return {"factors": [], "stats": {"error": "Alpha Mining not available"}} try: import torch from ..alpha_mining.utils import generate_mock_data # 准备特征数据 if market_data is not None: market_features = self._market_builder.build(market_data) time_steps = market_features.size(-1) if sentiment_data is not None: sentiment_features = self._sentiment_builder.build( sentiment_data, time_steps=time_steps ) features = self._sentiment_builder.combine_with_market( market_features, sentiment_features ) else: features = market_features returns = market_features[:, 0, :] # RET else: # 使用模拟数据 features, returns = generate_mock_data( num_samples=50, num_features=6, time_steps=252, seed=42 ) # 生成候选因子 formulas, _ = self._generator.generate(batch_size=20, max_len=8) # 评估每个因子 evaluated_factors = [] for formula in formulas: factor = self._vm.execute(formula, features) if factor is not None and factor.std() > 1e-6: try: metrics = self._evaluator.evaluate(factor, returns) evaluated_factors.append({ "formula": formula, "formula_str": self._vm.decode(formula), "sortino": metrics["sortino_ratio"], "sharpe": metrics["sharpe_ratio"], "ic": metrics["ic"], "max_drawdown": metrics["max_drawdown"] }) except Exception: continue # 按 Sortino 排序,取 top 5 evaluated_factors.sort(key=lambda x: x["sortino"], reverse=True) top_factors = evaluated_factors[:5] # 更新已发现因子 for f in top_factors: f["stock_code"] = stock_code f["discovered_at"] = datetime.utcnow().isoformat() self.discovered_factors.extend(top_factors) return { "factors": top_factors, "stats": { "generated": len(formulas), "valid": len(evaluated_factors), "top_sortino": top_factors[0]["sortino"] if top_factors else 0 } } except Exception as e: logger.error(f"Factor mining failed: {e}") return {"factors": [], "stats": {"error": str(e)}} async def _technical_analysis( self, stock_code: str, stock_name: str, market_data: Dict[str, Any], context: str ) -> Dict[str, Any]: """使用 LLM 进行技术分析""" # 提取关键指标 data_summary = self._summarize_market_data(market_data) prompt = f"""你是一位资深量化分析师,请对 {stock_name}({stock_code}) 进行技术分析。 行情数据摘要: {data_summary} {f'额外背景:{context}' if context else ''} 请分析: 1. 趋势判断(上涨/下跌/震荡) 2. 关键支撑位和阻力位 3. 技术指标信号(MA/MACD/RSI等) 4. 成交量分析 5. 
短期(1周)和中期(1月)预测 请以 JSON 格式返回: {{ "trend": "上涨/下跌/震荡", "support_levels": [价格1, 价格2], "resistance_levels": [价格1, 价格2], "technical_signals": {{ "ma_signal": "看涨/看跌/中性", "macd_signal": "看涨/看跌/中性", "rsi_signal": "超买/超卖/中性" }}, "volume_analysis": "放量/缩量/正常", "short_term_outlook": "看涨/看跌/中性", "medium_term_outlook": "看涨/看跌/中性", "confidence": 0.0-1.0 }}""" try: response = await self.llm_provider.chat(prompt) # 尝试解析 JSON start = response.find("{") end = response.rfind("}") + 1 if start >= 0 and end > start: return json.loads(response[start:end]) return {"raw_analysis": response} except Exception as e: logger.warning(f"Technical analysis parsing failed: {e}") return {"error": str(e)} async def _generate_strategy( self, stock_code: str, stock_name: str, analysis_result: Dict[str, Any], context: str ) -> Dict[str, Any]: """生成交易策略建议""" factors_summary = "" if analysis_result.get("factors_discovered"): factors = analysis_result["factors_discovered"][:3] factors_summary = "发现的有效因子:\n" for i, f in enumerate(factors, 1): factors_summary += f"{i}. {f['formula_str']} (Sortino={f['sortino']:.2f}, IC={f['ic']:.3f})\n" tech_summary = "" tech = analysis_result.get("technical_analysis", {}) if tech and not tech.get("error"): tech_summary = f"""技术分析结论: - 趋势:{tech.get('trend', 'N/A')} - 短期展望:{tech.get('short_term_outlook', 'N/A')} - 中期展望:{tech.get('medium_term_outlook', 'N/A')} """ prompt = f"""你是一位量化投资顾问,请为 {stock_name}({stock_code}) 生成交易策略建议。 {factors_summary} {tech_summary} {f'额外背景:{context}' if context else ''} 请提供: 1. 总体投资建议(买入/持有/卖出/观望) 2. 建议的仓位比例(0-100%) 3. 入场/出场价位建议 4. 风险控制建议(止损/止盈) 5. 策略置信度(0-1) 请以 JSON 格式返回: {{ "suggestion": "详细策略建议(100-200字)", "action": "买入/持有/卖出/观望", "position_ratio": 0-100, "entry_price": 价格或null, "exit_price": 价格或null, "stop_loss": 价格或null, "take_profit": 价格或null, "confidence": 0.0-1.0, "risk_level": "低/中/高" }}""" try: response = await self.llm_provider.chat(prompt) start = response.find("{") end = response.rfind("}") + 1 if start >= 0 and end > start: return json.loads(response[start:end]) return {"suggestion": response, "confidence": 0.5} except Exception as e: logger.warning(f"Strategy generation failed: {e}") return {"suggestion": "策略生成失败", "confidence": 0.0, "error": str(e)} def _summarize_market_data(self, market_data: Dict[str, Any]) -> str: """摘要行情数据""" if isinstance(market_data, dict): if "close" in market_data: close = market_data["close"] if hasattr(close, "tolist"): close = close.tolist() if isinstance(close, list) and len(close) > 0: return f""" - 最新价格:{close[-1]:.2f} - 最高价(近期):{max(close[-20:]):.2f} - 最低价(近期):{min(close[-20:]):.2f} - 价格变化(5日):{((close[-1]/close[-5])-1)*100:.2f}% - 价格变化(20日):{((close[-1]/close[-20])-1)*100:.2f}% """ return "行情数据格式不支持摘要" async def evaluate_factor( self, formula_str: str, market_data: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """评估指定因子表达式""" self._init_alpha_mining() if not self._alpha_mining_initialized: return {"success": False, "error": "Alpha Mining not available"} try: import torch from ..alpha_mining.utils import generate_mock_data # 解析公式 tokens = [] parts = formula_str.replace("(", " ").replace(")", " ").replace(",", " ").split() for part in parts: part = part.strip() if not part: continue try: token = self._vm.vocab.name_to_token(part) tokens.append(token) except (ValueError, KeyError): continue if not tokens: return {"success": False, "error": "Invalid formula"} # 准备数据 if market_data is not None: features = self._market_builder.build(market_data) returns = features[:, 0, :] else: features, returns = generate_mock_data() # 
执行 factor = self._vm.execute(tokens, features) if factor is None: return {"success": False, "error": "Factor execution failed"} # 评估 metrics = self._evaluator.evaluate(factor, returns) return { "success": True, "formula": formula_str, "metrics": metrics } except Exception as e: return {"success": False, "error": str(e)} def get_best_factors(self, top_k: int = 5) -> List[Dict[str, Any]]: """获取最优因子""" sorted_factors = sorted( self.discovered_factors, key=lambda x: x.get("sortino", 0), reverse=True ) return sorted_factors[:top_k] def create_quantitative_agent( llm_provider=None, enable_alpha_mining: bool = True, model_path: Optional[str] = None ) -> QuantitativeAgent: """ 创建量化分析智能体 Args: llm_provider: LLM 提供者 enable_alpha_mining: 是否启用因子挖掘 model_path: 预训练模型路径 Returns: QuantitativeAgent 实例 """ return QuantitativeAgent( llm_provider=llm_provider, enable_alpha_mining=enable_alpha_mining, model_path=model_path ) ================================================ FILE: backend/app/agents/search_analyst.py ================================================ """ 搜索分析师智能体 (SearchAnalystAgent) 负责在辩论过程中动态搜集数据,支持多种数据源: - AkShare: 财务指标、K线数据、资金流向、机构持仓 - BochaAI: 实时新闻搜索、分析师报告 - InteractiveCrawler: 多引擎网页搜索 (百度、搜狗、360等) - Knowledge Base: 历史新闻和上下文 (向量数据库) """ import logging import re import asyncio from typing import Dict, Any, List, Optional, ClassVar, Pattern from datetime import datetime from enum import Enum from agenticx.core.agent import Agent from ..services.llm_service import get_llm_provider from ..services.stock_data_service import stock_data_service from ..tools.bochaai_search import bochaai_search, SearchResult from ..tools.interactive_crawler import InteractiveCrawler logger = logging.getLogger(__name__) class SearchSource(Enum): """搜索数据源类型""" AKSHARE = "akshare" # AkShare 财务/行情数据 BOCHAAI = "bochaai" # BochaAI Web搜索 BROWSER = "browser" # 交互式浏览器搜索 KNOWLEDGE_BASE = "kb" # 内部知识库 ALL = "all" # 所有来源 class SearchAnalystAgent(Agent): """ 搜索分析师智能体 在辩论过程中被其他智能体调用,动态获取所需数据。 支持解析结构化搜索请求,并返回格式化的数据。 """ # 搜索请求的正则模式 [SEARCH: "query" source:xxx] # 使用 ClassVar 避免 Pydantic 将其视为模型字段 SEARCH_PATTERN: ClassVar[Pattern] = re.compile( r'\[SEARCH:\s*["\']([^"\']+)["\']\s*(?:source:(\w+))?\]', re.IGNORECASE ) def __init__(self, llm_provider=None, organization_id: str = "finnews"): super().__init__( name="SearchAnalyst", role="搜索分析师", goal="根据辩论中的数据需求,快速从多个数据源获取相关信息", backstory="""你是一位专业的金融数据搜索专家,精通各类金融数据源的使用。 你的职责是: 1. 解析辩论智能体的数据请求 2. 选择最合适的数据源进行查询 3. 整理并格式化数据,使其便于辩论使用 4. 
对数据质量进行初步评估 你能够访问的数据源包括: - AkShare: 股票财务指标、K线行情、资金流向、机构持仓等 - BochaAI: 实时新闻搜索、财经报道 - 多引擎搜索: 百度资讯、搜狗、360等 - 内部知识库: 历史新闻和分析数据""", organization_id=organization_id ) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) # 初始化搜索工具 self._interactive_crawler = InteractiveCrawler(timeout=20) logger.info(f"✅ Initialized {self.name} agent with multi-source search capabilities") def extract_search_requests(self, text: str) -> List[Dict[str, Any]]: """ 从文本中提取搜索请求 支持格式: - [SEARCH: "query"] - [SEARCH: "query" source:akshare] - [SEARCH: "query" source:bochaai] - [SEARCH: "query" source:browser] Args: text: 包含搜索请求的文本 Returns: 搜索请求列表 [{"query": "...", "source": "..."}] """ requests = [] matches = self.SEARCH_PATTERN.findall(text) for match in matches: query = match[0].strip() source = match[1].lower() if match[1] else "all" # 验证 source valid_sources = [s.value for s in SearchSource] if source not in valid_sources: source = "all" requests.append({ "query": query, "source": source }) logger.info(f"🔍 提取搜索请求: query='{query}', source={source}") return requests async def search( self, query: str, source: str = "all", stock_code: Optional[str] = None, stock_name: Optional[str] = None, context: Optional[str] = None ) -> Dict[str, Any]: """ 执行搜索请求 Args: query: 搜索查询 source: 数据源 (akshare, bochaai, browser, kb, all) stock_code: 股票代码 (用于 akshare 查询) stock_name: 股票名称 (用于新闻搜索) context: 额外上下文 Returns: 搜索结果字典 """ logger.info(f"🔍 SearchAnalyst: 执行搜索 query='{query}', source={source}") result = { "query": query, "source": source, "timestamp": datetime.utcnow().isoformat(), "data": {}, "summary": "", "success": False } try: if source == SearchSource.AKSHARE.value or source == SearchSource.ALL.value: akshare_data = await self._search_akshare(query, stock_code) if akshare_data: result["data"]["akshare"] = akshare_data if source == SearchSource.BOCHAAI.value or source == SearchSource.ALL.value: bochaai_data = await self._search_bochaai(query, stock_name) if bochaai_data: result["data"]["bochaai"] = bochaai_data if source == SearchSource.BROWSER.value or source == SearchSource.ALL.value: browser_data = await self._search_browser(query) if browser_data: result["data"]["browser"] = browser_data if source == SearchSource.KNOWLEDGE_BASE.value or source == SearchSource.ALL.value: kb_data = await self._search_knowledge_base(query, stock_code, stock_name) if kb_data: result["data"]["knowledge_base"] = kb_data # 生成摘要 if result["data"]: result["summary"] = await self._generate_summary(query, result["data"]) result["success"] = True else: result["summary"] = f"未找到与'{query}'相关的数据" except Exception as e: logger.error(f"SearchAnalyst 搜索失败: {e}", exc_info=True) result["error"] = str(e) return result async def _search_akshare( self, query: str, stock_code: Optional[str] = None ) -> Optional[Dict[str, Any]]: """从 AkShare 获取数据""" if not stock_code: logger.debug("AkShare 搜索需要股票代码,跳过") return None data = {} query_lower = query.lower() try: # 根据查询内容决定获取哪些数据 if any(kw in query_lower for kw in ["财务", "pe", "pb", "roe", "利润", "估值", "市盈", "市净"]): financial = await stock_data_service.get_financial_indicators(stock_code) if financial: data["financial_indicators"] = financial if any(kw in query_lower for kw in ["资金", "主力", "流入", "流出", "散户", "机构"]): fund_flow = await stock_data_service.get_fund_flow(stock_code, days=10) if fund_flow: data["fund_flow"] = fund_flow if any(kw in query_lower for kw in ["行情", "价格", "涨跌", "成交", "量"]): realtime = await stock_data_service.get_realtime_quote(stock_code) if 
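The `SEARCH_PATTERN` regex above drives `extract_search_requests`. A small self-contained demo of how it pulls `(query, source)` pairs out of a debate speech; the speech text is invented for illustration.

```python
import re

# Same pattern as SearchAnalystAgent.SEARCH_PATTERN.
SEARCH_PATTERN = re.compile(
    r'\[SEARCH:\s*["\']([^"\']+)["\']\s*(?:source:(\w+))?\]',
    re.IGNORECASE
)

speech = (
    '估值偏低,[SEARCH: "平安银行 ROE 财务指标" source:akshare] '
    '另外 [search: "平安银行 最新公告"] 需要补充。'
)

for query, source in SEARCH_PATTERN.findall(speech):
    # A missing source falls back to "all", as in extract_search_requests.
    print(query, "->", source.lower() if source else "all")
# 平安银行 ROE 财务指标 -> akshare
# 平安银行 最新公告 -> all
```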
realtime: data["realtime_quote"] = realtime if any(kw in query_lower for kw in ["k线", "走势", "历史", "均线", "趋势"]): kline = await stock_data_service.get_kline_data(stock_code, period="daily", limit=30) if kline: # 只返回最近10天的简要数据 data["kline_summary"] = { "period": "daily", "count": len(kline), "latest": kline[-1] if kline else None, "recent_5": kline[-5:] if len(kline) >= 5 else kline } # 如果没有匹配到特定查询,获取综合数据 if not data: context_data = await stock_data_service.get_debate_context(stock_code) if context_data: data = context_data if data: logger.info(f"✅ AkShare 返回数据: {list(data.keys())}") return data except Exception as e: logger.warning(f"AkShare 搜索出错: {e}") return None async def _search_bochaai( self, query: str, stock_name: Optional[str] = None ) -> Optional[Dict[str, Any]]: """从 BochaAI 搜索新闻""" if not bochaai_search.is_available(): logger.debug("BochaAI 未配置,跳过") return None try: # 构建搜索查询 search_query = query if stock_name and stock_name not in query: search_query = f"{stock_name} {query}" results = bochaai_search.search( query=search_query, freshness="oneWeek", count=10 ) if results: news_list = [ { "title": r.title, "snippet": r.snippet[:200] if r.snippet else "", "url": r.url, "source": r.site_name or "unknown", "date": r.date_published or "" } for r in results ] logger.info(f"✅ BochaAI 返回 {len(news_list)} 条新闻") return {"news": news_list, "count": len(news_list)} except Exception as e: logger.warning(f"BochaAI 搜索出错: {e}") return None async def _search_browser(self, query: str) -> Optional[Dict[str, Any]]: """使用交互式爬虫搜索""" try: loop = asyncio.get_event_loop() results = await loop.run_in_executor( None, lambda: self._interactive_crawler.interactive_search( query=query, engines=["baidu_news", "sogou"], num_results=10, search_type="news" ) ) if results: news_list = [ { "title": r.get("title", ""), "snippet": r.get("snippet", "")[:200], "url": r.get("url", ""), "source": "browser_search" } for r in results ] logger.info(f"✅ Browser 返回 {len(news_list)} 条结果") return {"search_results": news_list, "count": len(news_list)} except Exception as e: logger.warning(f"Browser 搜索出错: {e}") return None async def _search_knowledge_base( self, query: str, stock_code: Optional[str] = None, stock_name: Optional[str] = None ) -> Optional[Dict[str, Any]]: """从知识库搜索历史数据""" try: # 尝试导入 news_service(可能不存在) try: from ..services.news_service import news_service except ImportError: logger.debug("news_service 未配置,跳过知识库搜索") return None # 尝试从数据库获取相关新闻 if stock_code and news_service: news_list = await news_service.get_news_by_stock(stock_code, limit=10) if news_list: kb_news = [ { "title": getattr(news, 'title', ''), "content": (getattr(news, 'content', '') or '')[:300], "source": getattr(news, 'source', ''), "date": news.published_at.isoformat() if hasattr(news, 'published_at') and news.published_at else "" } for news in news_list ] logger.info(f"✅ KB 返回 {len(kb_news)} 条历史新闻") return {"historical_news": kb_news, "count": len(kb_news)} except Exception as e: logger.debug(f"KB 搜索出错: {e}") return None async def _generate_summary(self, query: str, data: Dict[str, Any]) -> str: """生成数据摘要""" summary_parts = [f"## 搜索结果: {query}\n"] # AkShare 数据摘要 if "akshare" in data: ak_data = data["akshare"] summary_parts.append("### 📊 财务/行情数据 (AkShare)\n") if "financial_indicators" in ak_data: fi = ak_data["financial_indicators"] summary_parts.append(f"- PE: {fi.get('pe_ratio', 'N/A')}, PB: {fi.get('pb_ratio', 'N/A')}") summary_parts.append(f"- ROE: {fi.get('roe', 'N/A')}%, 净利润同比: {fi.get('profit_yoy', 'N/A')}%") if "realtime_quote" in ak_data: 
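`_search_browser` wraps the synchronous crawler in `loop.run_in_executor` so the blocking HTTP work runs in a thread while the event loop keeps serving other coroutines. A minimal sketch of the same pattern, with `blocking_search` standing in for `InteractiveCrawler.interactive_search`:

```python
import asyncio

def blocking_search(query):
    # Hypothetical stand-in for the synchronous InteractiveCrawler call.
    return [{"title": f"result for {query}", "url": "http://example.com"}]

async def search_async(query):
    loop = asyncio.get_event_loop()
    # Off-load the blocking call to the default thread pool executor.
    return await loop.run_in_executor(None, lambda: blocking_search(query))

print(asyncio.run(search_async("贵州茅台 新闻")))
```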
rt = ak_data["realtime_quote"] summary_parts.append(f"- 当前价: {rt.get('price', 'N/A')}元, 涨跌幅: {rt.get('change_percent', 'N/A')}%") if "fund_flow" in ak_data: ff = ak_data["fund_flow"] main_net = ff.get('total_main_net', 0) trend = ff.get('main_flow_trend', 'N/A') summary_parts.append(f"- 资金流向: 近{ff.get('period_days', 5)}日主力{trend}") summary_parts.append("") # BochaAI 新闻摘要 if "bochaai" in data: news = data["bochaai"].get("news", []) if news: summary_parts.append("### 📰 最新新闻 (BochaAI)\n") for i, n in enumerate(news[:5], 1): summary_parts.append(f"{i}. **{n['title'][:50]}**") if n.get('snippet'): summary_parts.append(f" {n['snippet'][:100]}...") summary_parts.append("") # Browser 搜索结果摘要 if "browser" in data: results = data["browser"].get("search_results", []) if results: summary_parts.append("### 🌐 网页搜索结果\n") for i, r in enumerate(results[:5], 1): summary_parts.append(f"{i}. {r['title'][:50]}") summary_parts.append("") # KB 历史数据摘要 if "knowledge_base" in data: kb_news = data["knowledge_base"].get("historical_news", []) if kb_news: summary_parts.append("### 📚 历史资料 (知识库)\n") for i, n in enumerate(kb_news[:3], 1): summary_parts.append(f"{i}. {n['title'][:50]}") summary_parts.append("") return "\n".join(summary_parts) async def process_debate_speech( self, speech_text: str, stock_code: str, stock_name: str, agent_name: str = "Unknown" ) -> Dict[str, Any]: """ 处理辩论发言中的搜索请求 Args: speech_text: 辩论发言文本 stock_code: 股票代码 stock_name: 股票名称 agent_name: 发言智能体名称 Returns: 处理结果,包含所有搜索结果和综合摘要 """ logger.info(f"🔍 SearchAnalyst: 处理 {agent_name} 的发言,检测搜索请求...") result = { "agent_name": agent_name, "requests_found": 0, "search_results": [], "combined_summary": "", "success": False } # 提取搜索请求 requests = self.extract_search_requests(speech_text) result["requests_found"] = len(requests) if not requests: logger.info(f"📝 {agent_name} 的发言中未包含搜索请求") result["success"] = True return result logger.info(f"📋 从 {agent_name} 的发言中提取到 {len(requests)} 个搜索请求") # 并行执行所有搜索 search_tasks = [] for req in requests: task = self.search( query=req["query"], source=req["source"], stock_code=stock_code, stock_name=stock_name ) search_tasks.append(task) search_results = await asyncio.gather(*search_tasks, return_exceptions=True) # 收集结果 summaries = [] for i, res in enumerate(search_results): if isinstance(res, Exception): logger.error(f"搜索请求 {i+1} 失败: {res}") continue if res.get("success"): result["search_results"].append(res) summaries.append(res.get("summary", "")) # 生成综合摘要 if summaries: result["combined_summary"] = "\n---\n".join(summaries) result["success"] = True logger.info(f"✅ SearchAnalyst: 为 {agent_name} 完成 {len(result['search_results'])} 个搜索请求") return result async def smart_data_supplement( self, stock_code: str, stock_name: str, existing_context: str, debate_history: List[Dict[str, Any]] ) -> Dict[str, Any]: """ 智能数据补充 分析辩论历史和现有上下文,主动识别缺失的关键数据并补充 Args: stock_code: 股票代码 stock_name: 股票名称 existing_context: 现有上下文 debate_history: 辩论历史 Returns: 补充的数据和摘要 """ logger.info(f"🧠 SearchAnalyst: 智能分析数据缺口...") # 使用 LLM 分析需要什么数据 analysis_prompt = f"""你是一位金融数据分析专家。请分析以下辩论情况,判断还需要哪些数据支撑: 【股票】{stock_name} ({stock_code}) 【现有数据】 {existing_context[:1500]} 【辩论历史】 {self._format_debate_history(debate_history[-4:])} 请判断: 1. 看多方缺少什么关键数据? 2. 看空方缺少什么关键数据? 3. 还需要搜索什么信息? 
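`process_debate_speech` above fans the extracted requests out with `asyncio.gather(..., return_exceptions=True)`, so a single failing source does not abort the whole round. A toy sketch of that failure-tolerant pattern, with `fake_search` as a made-up stand-in for `SearchAnalystAgent.search`:

```python
import asyncio

async def fake_search(query):
    # Hypothetical stand-in for SearchAnalystAgent.search().
    if "bad" in query:
        raise RuntimeError("source unavailable")
    return {"query": query, "success": True}

async def run_all():
    tasks = [fake_search(q) for q in ["资金流向", "bad query", "机构持仓"]]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for res in results:
        if isinstance(res, Exception):
            print("skipped:", res)      # one failed search does not abort the others
        else:
            print("ok:", res["query"])

asyncio.run(run_all())
```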
请按以下格式输出需要搜索的内容(每行一个): [SEARCH: "搜索内容" source:数据源] 可用数据源:akshare(财务/行情), bochaai(新闻), browser(网页搜索) 只输出3-5个最关键的搜索请求。""" try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": analysis_prompt} ]) llm_response = response.content if hasattr(response, 'content') else str(response) # 处理 LLM 建议的搜索 return await self.process_debate_speech( speech_text=llm_response, stock_code=stock_code, stock_name=stock_name, agent_name="SmartSupplement" ) except Exception as e: logger.error(f"智能数据补充失败: {e}") return {"success": False, "error": str(e)} def _format_debate_history(self, history: List[Dict[str, Any]]) -> str: """格式化辩论历史""" if not history: return "(暂无辩论历史)" lines = [] for item in history: agent = item.get("agent", "Unknown") content = item.get("content", "")[:300] lines.append(f"[{agent}]: {content}") return "\n\n".join(lines) # 工厂函数 def create_search_analyst(llm_provider=None) -> SearchAnalystAgent: """创建搜索分析师实例""" return SearchAnalystAgent(llm_provider=llm_provider) ================================================ FILE: backend/app/alpha_mining/README.md ================================================ # M12: Alpha Mining 量化因子挖掘模块 基于 AlphaGPT 技术的量化因子自动挖掘模块,使用符号回归 + 强化学习自动发现有预测能力的交易因子。 ## 功能特性 - **因子自动发现**:使用 Transformer + RL 自动生成和优化因子表达式 - **DSL 表达式系统**:支持丰富的时序操作符(MA、STD、DELAY、DELTA 等) - **情感特征融合**:可结合新闻情感分析提升因子效果 - **回测评估**:内置 Sortino/Sharpe/IC 等多种评估指标 - **AgenticX 集成**:提供 BaseTool 封装,供 Agent 调用 ## 模块结构 ``` alpha_mining/ ├── __init__.py # 模块入口 ├── config.py # 配置管理 ├── utils.py # 工具函数 ├── dsl/ # 因子表达式 DSL │ ├── ops.py # 操作符定义 │ └── vocab.py # 词汇表管理 ├── vm/ # 因子执行器 │ └── factor_vm.py # 栈式虚拟机 ├── model/ # 生成模型 │ ├── alpha_generator.py # Transformer 策略网络 │ └── trainer.py # RL 训练器 ├── features/ # 特征构建 │ ├── market.py # 行情特征 │ └── sentiment.py # 情感特征 ├── backtest/ # 回测评估 │ └── evaluator.py # 因子评估器 └── tools/ # AgenticX 工具 └── alpha_mining_tool.py ``` ## 快速开始 ### 基础使用 ```python from app.alpha_mining import ( AlphaGenerator, AlphaTrainer, FactorVM, FactorEvaluator, generate_mock_data ) # 1. 准备数据 features, returns = generate_mock_data( num_samples=50, num_features=6, time_steps=252 ) # 2. 创建训练器 trainer = AlphaTrainer() # 3. 
训练挖掘因子 result = trainer.train( features=features, returns=returns, num_steps=100 ) print(f"最优因子: {result['best_formula_str']}") print(f"得分: {result['best_score']:.4f}") ``` ### 使用 QuantitativeAgent ```python from app.agents import QuantitativeAgent # 创建智能体 agent = QuantitativeAgent( llm_provider=llm, enable_alpha_mining=True ) # 执行分析 result = await agent.analyze( stock_code="000001", stock_name="平安银行", market_data=market_data, sentiment_data=sentiment_data ) # 获取发现的因子 for factor in result["factors_discovered"]: print(f"{factor['formula_str']}: Sortino={factor['sortino']:.2f}") ``` ### REST API ```bash # 启动因子挖掘任务 curl -X POST http://localhost:8000/api/v1/alpha-mining/mine \ -H "Content-Type: application/json" \ -d '{"num_steps": 100, "use_sentiment": true}' # 评估因子表达式 curl -X POST http://localhost:8000/api/v1/alpha-mining/evaluate \ -H "Content-Type: application/json" \ -d '{"formula": "ADD(RET, MA5(VOL))"}' # 获取已发现的因子 curl http://localhost:8000/api/v1/alpha-mining/factors?top_k=10 ``` ## 支持的操作符 ### 算术操作符 | 操作符 | 参数数 | 描述 | |--------|--------|------| | ADD | 2 | 加法 | | SUB | 2 | 减法 | | MUL | 2 | 乘法 | | DIV | 2 | 除法 | ### 一元操作符 | 操作符 | 参数数 | 描述 | |--------|--------|------| | NEG | 1 | 取负 | | ABS | 1 | 绝对值 | | SIGN | 1 | 符号函数 | ### 时序操作符 | 操作符 | 参数数 | 描述 | |--------|--------|------| | DELAY1/5 | 1 | 延迟 1/5 期 | | DELTA1/5 | 1 | 差分 1/5 期 | | MA5/10 | 1 | 5/10 日移动平均 | | STD5/10 | 1 | 5/10 日滚动标准差 | ### 条件操作符 | 操作符 | 参数数 | 描述 | |--------|--------|------| | GATE | 3 | 条件选择 | | MAX | 2 | 取最大值 | | MIN | 2 | 取最小值 | ## 特征列表 | 特征 | 描述 | 数据来源 | |------|------|----------| | RET | 收益率 | 行情数据 | | VOL | 波动率 | 行情数据 | | VOLUME_CHG | 成交量变化 | 行情数据 | | TURNOVER | 换手率 | 行情数据 | | SENTIMENT | 情感分数 | 新闻分析 | | NEWS_COUNT | 新闻数量 | 新闻分析 | ## 评估指标 - **Sortino Ratio**: 风险调整收益(只考虑下行风险) - **Sharpe Ratio**: 风险调整收益 - **IC**: 信息系数(因子与收益的相关性) - **Rank IC**: 排名信息系数 - **Max Drawdown**: 最大回撤 - **Turnover**: 换手率 ## 配置选项 ```python from app.alpha_mining import AlphaMiningConfig config = AlphaMiningConfig( # 模型参数 d_model=64, # Transformer 隐藏维度 num_layers=2, # Transformer 层数 nhead=4, # 注意力头数 max_seq_len=12, # 最大序列长度 # 训练参数 batch_size=1024, # 批量大小 lr=1e-3, # 学习率 num_steps=1000, # 训练步数 # 奖励参数 invalid_formula_reward=-5.0, # 无效公式惩罚 constant_factor_reward=-2.0, # 常量因子惩罚 # 回测参数 cost_rate=0.0015, # 交易成本率 signal_threshold=0.7, # 信号阈值 # 特征配置 enable_sentiment=True, # 启用情感特征 ) ``` ## 参考 - [AlphaGPT](https://github.com/imbue-bit/AlphaGPT) - 原始实现 ================================================ FILE: backend/app/alpha_mining/__init__.py ================================================ """ M12: Alpha Mining Module for FinnewsHunter 基于 AlphaGPT 技术的量化因子自动挖掘模块。 使用符号回归 + 强化学习自动发现有预测能力的交易因子。 核心组件: - dsl: 因子表达式 DSL(操作符、词汇表) - vm: 因子执行器(StackVM) - model: 因子生成模型(AlphaGenerator)和训练器(AlphaTrainer) - features: 特征构建器(行情、情感) - backtest: 因子回测评估 - tools: AgenticX 工具封装 References: - AlphaGPT: https://github.com/imbue-bit/AlphaGPT - 技术方案: researches/AlphaGPT/AlphaGPT_proposal.md """ __version__ = "0.1.0" __author__ = "FinnewsHunter Team" from .config import AlphaMiningConfig, DEFAULT_CONFIG from .dsl.vocab import FactorVocab, DEFAULT_VOCAB from .dsl.ops import OPS_CONFIG from .vm.factor_vm import FactorVM from .model.alpha_generator import AlphaGenerator from .model.trainer import AlphaTrainer from .features.market import MarketFeatureBuilder from .features.sentiment import SentimentFeatureBuilder from .backtest.evaluator import FactorEvaluator from .utils import generate_mock_data __all__ = [ # Config "AlphaMiningConfig", "DEFAULT_CONFIG", # DSL "FactorVocab", 
"DEFAULT_VOCAB", "OPS_CONFIG", # VM "FactorVM", # Model "AlphaGenerator", "AlphaTrainer", # Features "MarketFeatureBuilder", "SentimentFeatureBuilder", # Backtest "FactorEvaluator", # Utils "generate_mock_data", ] ================================================ FILE: backend/app/alpha_mining/backtest/__init__.py ================================================ """ 因子回测评估模块 提供因子有效性评估,包括 Sortino Ratio 等指标计算。 """ from .evaluator import FactorEvaluator __all__ = ["FactorEvaluator"] ================================================ FILE: backend/app/alpha_mining/backtest/evaluator.py ================================================ """ 因子回测评估器 评估因子的预测能力和交易表现。 评估指标: - Sortino Ratio: 风险调整收益(只考虑下行风险) - Sharpe Ratio: 风险调整收益 - IC: 信息系数(因子与收益的相关性) - Rank IC: 排名信息系数 - Turnover: 换手率 - Max Drawdown: 最大回撤 """ import torch from typing import Dict, Optional, List, Tuple import numpy as np import logging from ..config import AlphaMiningConfig, DEFAULT_CONFIG logger = logging.getLogger(__name__) class FactorEvaluator: """ 因子回测评估器 评估因子表达式的有效性和收益表现。 Args: config: 配置实例 cost_rate: 交易成本率 signal_threshold: 信号阈值(用于生成持仓) Example: evaluator = FactorEvaluator() metrics = evaluator.evaluate(factor, returns) """ def __init__( self, config: Optional[AlphaMiningConfig] = None, cost_rate: Optional[float] = None, signal_threshold: Optional[float] = None ): self.config = config or DEFAULT_CONFIG self.cost_rate = cost_rate or self.config.cost_rate self.signal_threshold = signal_threshold or self.config.signal_threshold # 年化系数(假设 252 个交易日) self.annualize_factor = np.sqrt(252) logger.info( f"FactorEvaluator initialized: " f"cost_rate={self.cost_rate}, threshold={self.signal_threshold}" ) def evaluate( self, factor: torch.Tensor, returns: torch.Tensor, benchmark: Optional[torch.Tensor] = None ) -> Dict[str, float]: """ 综合评估因子 Args: factor: 因子值 [batch, time_steps] 或 [time_steps] returns: 收益率 [batch, time_steps] 或 [time_steps] benchmark: 基准收益率(可选) Returns: 评估指标字典 """ # 确保是 2D if factor.dim() == 1: factor = factor.unsqueeze(0) if returns.dim() == 1: returns = returns.unsqueeze(0) # 对每个样本计算指标,然后平均 metrics_list = [] for i in range(factor.size(0)): f = factor[i] r = returns[i] b = benchmark[i] if benchmark is not None else None m = self._evaluate_single(f, r, b) metrics_list.append(m) # 聚合指标 result = {} for key in metrics_list[0].keys(): values = [m[key] for m in metrics_list] result[key] = np.mean(values) result[f"{key}_std"] = np.std(values) return result def _evaluate_single( self, factor: torch.Tensor, returns: torch.Tensor, benchmark: Optional[torch.Tensor] = None ) -> Dict[str, float]: """评估单个样本""" # 转换为 numpy factor_np = factor.detach().cpu().numpy() returns_np = returns.detach().cpu().numpy() # 生成信号和持仓 signal = self._factor_to_signal(factor_np) position = self._signal_to_position(signal) # 计算策略收益 strategy_returns = position[:-1] * returns_np[1:] # 计算交易成本 turnover = np.abs(np.diff(position)).mean() net_returns = strategy_returns - turnover * self.cost_rate # 计算各指标 metrics = { "sortino_ratio": self._calc_sortino(net_returns), "sharpe_ratio": self._calc_sharpe(net_returns), "ic": self._calc_ic(factor_np, returns_np), "rank_ic": self._calc_rank_ic(factor_np, returns_np), "turnover": turnover, "max_drawdown": self._calc_max_drawdown(net_returns), "total_return": np.sum(net_returns), "win_rate": np.mean(net_returns > 0), "avg_return": np.mean(net_returns), } # 相对基准的超额收益 if benchmark is not None: benchmark_np = benchmark.detach().cpu().numpy() excess_returns = net_returns - benchmark_np[1:] metrics["excess_return"] = 
np.sum(excess_returns) metrics["information_ratio"] = self._calc_sharpe(excess_returns) return metrics def _factor_to_signal(self, factor: np.ndarray) -> np.ndarray: """因子值转换为信号(-1 到 1)""" # 使用 Z-score 标准化 mean = np.mean(factor) std = np.std(factor) + 1e-8 z_score = (factor - mean) / std # Sigmoid 映射到 (-1, 1) signal = 2 / (1 + np.exp(-z_score)) - 1 return signal def _signal_to_position(self, signal: np.ndarray) -> np.ndarray: """信号转换为持仓""" position = np.zeros_like(signal) # 信号大于阈值时做多 position[signal > self.signal_threshold] = 1.0 # 信号小于负阈值时做空 position[signal < -self.signal_threshold] = -1.0 # 中间区域不持仓 return position def _calc_sortino(self, returns: np.ndarray) -> float: """ 计算 Sortino Ratio 只考虑下行风险(负收益的标准差) """ mean_return = np.mean(returns) downside = returns[returns < 0] if len(downside) == 0: return float('inf') if mean_return > 0 else 0.0 downside_std = np.std(downside) + 1e-8 sortino = mean_return / downside_std * self.annualize_factor return float(sortino) def _calc_sharpe(self, returns: np.ndarray) -> float: """计算 Sharpe Ratio""" mean_return = np.mean(returns) std_return = np.std(returns) + 1e-8 sharpe = mean_return / std_return * self.annualize_factor return float(sharpe) def _calc_ic(self, factor: np.ndarray, returns: np.ndarray) -> float: """ 计算 IC (Information Coefficient) 因子值与下期收益的 Pearson 相关系数 """ # 对齐:factor[t] 预测 returns[t+1] factor_lag = factor[:-1] returns_lead = returns[1:] # Pearson 相关 corr = np.corrcoef(factor_lag, returns_lead)[0, 1] return float(corr) if not np.isnan(corr) else 0.0 def _calc_rank_ic(self, factor: np.ndarray, returns: np.ndarray) -> float: """ 计算 Rank IC 因子排名与收益排名的 Spearman 相关系数 """ from scipy.stats import spearmanr factor_lag = factor[:-1] returns_lead = returns[1:] try: corr, _ = spearmanr(factor_lag, returns_lead) return float(corr) if not np.isnan(corr) else 0.0 except Exception: return 0.0 def _calc_max_drawdown(self, returns: np.ndarray) -> float: """计算最大回撤""" cumulative = np.cumsum(returns) running_max = np.maximum.accumulate(cumulative) drawdown = running_max - cumulative max_dd = np.max(drawdown) return float(max_dd) def get_reward( self, factor: torch.Tensor, returns: torch.Tensor ) -> float: """ 获取强化学习奖励 使用 Sortino Ratio 作为奖励信号。 Args: factor: 因子值 returns: 收益率 Returns: 奖励值 """ metrics = self.evaluate(factor, returns) # 主要使用 Sortino Ratio reward = metrics["sortino_ratio"] # 惩罚过高的换手率 if metrics["turnover"] > 0.5: reward -= (metrics["turnover"] - 0.5) * 2 # 惩罚过大的最大回撤 if metrics["max_drawdown"] > 0.2: reward -= (metrics["max_drawdown"] - 0.2) * 5 return float(reward) def compare_factors( self, factors: List[torch.Tensor], returns: torch.Tensor, factor_names: Optional[List[str]] = None ) -> Dict[str, Dict[str, float]]: """ 比较多个因子的表现 Args: factors: 因子列表 returns: 收益率 factor_names: 因子名称列表 Returns: {factor_name: metrics_dict} """ if factor_names is None: factor_names = [f"factor_{i}" for i in range(len(factors))] results = {} for name, factor in zip(factor_names, factors): results[name] = self.evaluate(factor, returns) return results def rank_factors( self, factors: List[torch.Tensor], returns: torch.Tensor, metric: str = "sortino_ratio" ) -> List[Tuple[int, float]]: """ 对因子按指定指标排名 Args: factors: 因子列表 returns: 收益率 metric: 排名指标 Returns: [(index, score), ...] 
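The metric helpers above (`_calc_sortino`, `_calc_sharpe`, `_calc_ic`) reduce to a few NumPy one-liners. A toy worked example on invented daily returns, using the same `sqrt(252)` annualization and the same "factor today predicts return tomorrow" alignment for the IC:

```python
import numpy as np

returns = np.array([0.01, -0.004, 0.006, -0.012, 0.008, 0.003])  # made-up daily net returns
ann = np.sqrt(252)

downside = returns[returns < 0]
sortino = returns.mean() / (downside.std() + 1e-8) * ann   # penalize only downside volatility
sharpe = returns.mean() / (returns.std() + 1e-8) * ann

factor = np.array([0.2, -0.1, 0.3, -0.4, 0.1, 0.0])
ic = np.corrcoef(factor[:-1], returns[1:])[0, 1]            # factor[t] vs returns[t+1]

print(f"sortino={sortino:.2f} sharpe={sharpe:.2f} ic={ic:.3f}")
```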
按 score 降序排列 """ scores = [] for i, factor in enumerate(factors): metrics = self.evaluate(factor, returns) scores.append((i, metrics.get(metric, 0))) # 降序排列 scores.sort(key=lambda x: x[1], reverse=True) return scores ================================================ FILE: backend/app/alpha_mining/config.py ================================================ """ Alpha Mining 配置模块 定义训练、模型、回测等配置参数。 References: - AlphaGPT upstream/model_core/config.py """ import torch from dataclasses import dataclass, field from typing import List, Optional @dataclass class AlphaMiningConfig: """Alpha Mining 模块配置""" # ============ 设备配置 ============ device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu") # ============ 模型配置 ============ d_model: int = 64 # Transformer 嵌入维度 nhead: int = 4 # 注意力头数 num_layers: int = 2 # Transformer 层数 max_seq_len: int = 12 # 最大因子表达式长度 # ============ 训练配置 ============ batch_size: int = 1024 # 批量大小(每批生成的因子数) num_steps: int = 1000 # 训练步数 lr: float = 1e-3 # 学习率 # ============ 奖励配置 ============ invalid_formula_reward: float = -5.0 # 无效公式惩罚 constant_factor_reward: float = -2.0 # 常量因子惩罚 low_activity_reward: float = -10.0 # 低活跃度惩罚 constant_threshold: float = 1e-4 # 常量因子阈值(std < 此值视为常量) # ============ 回测配置 ============ cost_rate: float = 0.0015 # A股交易费率(双边约0.3%) signal_threshold: float = 0.7 # 信号阈值(factor > threshold 时建仓) min_holding_days: int = 1 # 最小持仓天数 min_activity: int = 5 # 最小活跃度(持仓天数) # ============ 特征配置 ============ market_features: List[str] = field(default_factory=lambda: [ "RET", # 收益率 "VOL", # 波动率 "VOLUME_CHG", # 成交量变化 "TURNOVER", # 换手率 ]) sentiment_features: List[str] = field(default_factory=lambda: [ "SENTIMENT", # 情感分数 "NEWS_COUNT", # 新闻数量 ]) enable_sentiment: bool = True # 是否启用情感特征 # ============ 持久化配置 ============ checkpoint_dir: str = "checkpoints/alpha_mining" save_every_n_steps: int = 100 @property def torch_device(self) -> torch.device: """获取 torch.device 对象""" return torch.device(self.device) @property def all_features(self) -> List[str]: """获取所有启用的特征列表""" features = self.market_features.copy() if self.enable_sentiment: features.extend(self.sentiment_features) return features @property def num_features(self) -> int: """特征数量""" return len(self.all_features) # 默认配置实例 DEFAULT_CONFIG = AlphaMiningConfig() ================================================ FILE: backend/app/alpha_mining/dsl/__init__.py ================================================ """ 因子表达式 DSL(Domain Specific Language) 包含操作符定义和词汇表管理。 """ from .ops import OPS_CONFIG, ts_delay, ts_delta, ts_mean, ts_std from .vocab import FactorVocab, FEATURES __all__ = [ "OPS_CONFIG", "ts_delay", "ts_delta", "ts_mean", "ts_std", "FactorVocab", "FEATURES", ] ================================================ FILE: backend/app/alpha_mining/dsl/ops.py ================================================ """ 因子操作符定义 定义因子表达式中可用的操作符,包括: - 算术运算:ADD, SUB, MUL, DIV - 一元运算:NEG, ABS, SIGN - 时序运算:DELAY, DELTA, MA, STD - 条件运算:GATE, MAX, MIN References: - AlphaGPT upstream/model_core/ops.py """ import torch from typing import Callable, Tuple, List # ============================================================================ # 时序操作函数(优化版本,支持 JIT 编译) # ============================================================================ def ts_delay(x: torch.Tensor, d: int = 1) -> torch.Tensor: """ 时序延迟:将序列向右移动 d 步 Args: x: [batch, time_steps] 输入张量 d: 延迟步数 Returns: 延迟后的张量,前 d 个位置填充 0 """ if d == 0: return x if d < 0: raise ValueError(f"Delay must be non-negative, got {d}") batch_size = x.shape[0] pad = 
torch.zeros((batch_size, d), device=x.device, dtype=x.dtype) return torch.cat([pad, x[:, :-d]], dim=1) def ts_delta(x: torch.Tensor, d: int = 1) -> torch.Tensor: """ 时序差分:计算 x[t] - x[t-d] Args: x: [batch, time_steps] 输入张量 d: 差分步数 Returns: 差分后的张量 """ return x - ts_delay(x, d) def ts_mean(x: torch.Tensor, window: int = 5) -> torch.Tensor: """ 滑动平均 Args: x: [batch, time_steps] 输入张量 window: 窗口大小 Returns: 滑动平均后的张量 """ if window <= 0: raise ValueError(f"Window must be positive, got {window}") # 使用 unfold 实现滑动窗口 batch_size, time_steps = x.shape # Padding pad = torch.zeros((batch_size, window - 1), device=x.device, dtype=x.dtype) x_padded = torch.cat([pad, x], dim=1) # 滑动窗口平均 result = x_padded.unfold(1, window, 1).mean(dim=-1) return result def ts_std(x: torch.Tensor, window: int = 5) -> torch.Tensor: """ 滑动标准差 Args: x: [batch, time_steps] 输入张量 window: 窗口大小 Returns: 滑动标准差后的张量 """ if window <= 0: raise ValueError(f"Window must be positive, got {window}") batch_size, time_steps = x.shape # Padding pad = torch.zeros((batch_size, window - 1), device=x.device, dtype=x.dtype) x_padded = torch.cat([pad, x], dim=1) # 滑动窗口标准差 result = x_padded.unfold(1, window, 1).std(dim=-1) return result def _op_gate(condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ 条件选择:condition > 0 时返回 x,否则返回 y 类似于 torch.where(condition > 0, x, y) """ mask = (condition > 0).float() return mask * x + (1.0 - mask) * y def _op_jump(x: torch.Tensor) -> torch.Tensor: """ 跳跃检测:返回超过 3 sigma 的异常值 用于检测价格跳跃/异常波动 """ mean = x.mean(dim=1, keepdim=True) std = x.std(dim=1, keepdim=True) + 1e-6 z = (x - mean) / std return torch.relu(z - 3.0) def _op_decay(x: torch.Tensor) -> torch.Tensor: """ 衰减加权:x + 0.8*x[-1] + 0.6*x[-2] 给近期数据更高权重 """ return x + 0.8 * ts_delay(x, 1) + 0.6 * ts_delay(x, 2) def _op_max3(x: torch.Tensor) -> torch.Tensor: """ 3 期最大值 """ return torch.max(x, torch.max(ts_delay(x, 1), ts_delay(x, 2))) # ============================================================================ # 操作符配置 # ============================================================================ # 操作符配置格式:(name, function, arity) # - name: 操作符名称 # - function: 操作符函数 # - arity: 参数数量(1=一元,2=二元,3=三元) OPS_CONFIG: List[Tuple[str, Callable, int]] = [ # 二元算术运算 ('ADD', lambda x, y: x + y, 2), ('SUB', lambda x, y: x - y, 2), ('MUL', lambda x, y: x * y, 2), ('DIV', lambda x, y: x / (y + 1e-6), 2), # 安全除法 # 一元运算 ('NEG', lambda x: -x, 1), ('ABS', torch.abs, 1), ('SIGN', torch.sign, 1), # 条件运算 ('GATE', _op_gate, 3), # 条件选择 ('MAX', lambda x, y: torch.max(x, y), 2), ('MIN', lambda x, y: torch.min(x, y), 2), # 时序运算 ('DELAY1', lambda x: ts_delay(x, 1), 1), ('DELAY5', lambda x: ts_delay(x, 5), 1), ('DELTA1', lambda x: ts_delta(x, 1), 1), ('DELTA5', lambda x: ts_delta(x, 5), 1), ('MA5', lambda x: ts_mean(x, 5), 1), ('MA10', lambda x: ts_mean(x, 10), 1), ('STD5', lambda x: ts_std(x, 5), 1), ('STD10', lambda x: ts_std(x, 10), 1), # 特殊运算 ('JUMP', _op_jump, 1), ('DECAY', _op_decay, 1), ('MAX3', _op_max3, 1), ] def get_op_names() -> List[str]: """获取所有操作符名称""" return [op[0] for op in OPS_CONFIG] def get_op_by_name(name: str) -> Tuple[Callable, int]: """ 根据名称获取操作符函数和参数数量 Args: name: 操作符名称 Returns: (function, arity) 元组 Raises: KeyError: 如果操作符不存在 """ for op_name, func, arity in OPS_CONFIG: if op_name == name: return func, arity raise KeyError(f"Unknown operator: {name}") def get_num_ops() -> int: """获取操作符数量""" return len(OPS_CONFIG) ================================================ FILE: backend/app/alpha_mining/dsl/vocab.py 
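A quick sanity check of the time-series operators defined in `ops.py` above, run on a tiny batch; the import path follows the module layout shown in the README.

```python
import torch
from app.alpha_mining.dsl.ops import ts_delay, ts_delta, ts_mean

x = torch.tensor([[1., 2., 3., 4., 5., 6.]])  # batch of 1, 6 time steps

print(ts_delay(x, 1))  # tensor([[0., 1., 2., 3., 4., 5.]])  -- shifted right, zero-padded
print(ts_delta(x, 1))  # tensor([[1., 1., 1., 1., 1., 1.]])  -- x[t] - x[t-1]
print(ts_mean(x, 3))   # rolling 3-step mean, zero-padded at the start of the window
```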
================================================ """ 因子词汇表管理 管理因子表达式中的 token 词汇表,包括: - 特征 token(RET, VOL, VOLUME_CHG 等) - 操作符 token(ADD, SUB, MUL 等) 提供 token <-> name 双向映射。 References: - AlphaGPT upstream/model_core/alphagpt.py:10-14 """ from typing import List, Dict, Optional from dataclasses import dataclass, field from .ops import OPS_CONFIG, get_op_names # 默认特征列表 FEATURES: List[str] = [ "RET", # 收益率 "VOL", # 波动率 "VOLUME_CHG", # 成交量变化 "TURNOVER", # 换手率 "SENTIMENT", # 情感分数 "NEWS_COUNT", # 新闻数量 ] @dataclass class FactorVocab: """ 因子词汇表 词汇表结构:[FEATURES..., OPERATORS...] - 前 num_features 个 token 是特征 - 后 num_ops 个 token 是操作符 Example: vocab = FactorVocab(features=["RET", "VOL"]) vocab.token_to_name(0) # -> "RET" vocab.name_to_token("ADD") # -> 2 (假设有 2 个特征) """ features: List[str] = field(default_factory=lambda: FEATURES.copy()) def __post_init__(self): """初始化词汇表映射""" self._operators = get_op_names() self._vocab = self.features + self._operators # 构建映射 self._token_to_name: Dict[int, str] = { i: name for i, name in enumerate(self._vocab) } self._name_to_token: Dict[str, int] = { name: i for i, name in enumerate(self._vocab) } @property def vocab_size(self) -> int: """词汇表大小""" return len(self._vocab) @property def num_features(self) -> int: """特征数量""" return len(self.features) @property def num_ops(self) -> int: """操作符数量""" return len(self._operators) @property def feature_offset(self) -> int: """特征 token 的结束位置(也是操作符的起始位置)""" return self.num_features def token_to_name(self, token: int) -> str: """ 将 token ID 转换为名称 Args: token: token ID Returns: token 对应的名称 Raises: KeyError: 如果 token 不存在 """ if token not in self._token_to_name: raise KeyError(f"Unknown token: {token}") return self._token_to_name[token] def name_to_token(self, name: str) -> int: """ 将名称转换为 token ID Args: name: 特征或操作符名称 Returns: 对应的 token ID Raises: KeyError: 如果名称不存在 """ if name not in self._name_to_token: raise KeyError(f"Unknown name: {name}") return self._name_to_token[name] def is_feature(self, token: int) -> bool: """判断 token 是否为特征""" return 0 <= token < self.feature_offset def is_operator(self, token: int) -> bool: """判断 token 是否为操作符""" return self.feature_offset <= token < self.vocab_size def get_operator_arity(self, token: int) -> int: """ 获取操作符的参数数量 Args: token: 操作符 token ID Returns: 参数数量(1, 2 或 3) Raises: ValueError: 如果不是操作符 """ if not self.is_operator(token): raise ValueError(f"Token {token} is not an operator") op_index = token - self.feature_offset return OPS_CONFIG[op_index][2] def get_operator_func(self, token: int): """ 获取操作符的函数 Args: token: 操作符 token ID Returns: 操作符函数 Raises: ValueError: 如果不是操作符 """ if not self.is_operator(token): raise ValueError(f"Token {token} is not an operator") op_index = token - self.feature_offset return OPS_CONFIG[op_index][1] def get_all_tokens(self) -> List[int]: """获取所有 token ID""" return list(range(self.vocab_size)) def get_feature_tokens(self) -> List[int]: """获取所有特征 token ID""" return list(range(self.num_features)) def get_operator_tokens(self) -> List[int]: """获取所有操作符 token ID""" return list(range(self.feature_offset, self.vocab_size)) def __repr__(self) -> str: return f"FactorVocab(features={self.features}, vocab_size={self.vocab_size})" # 默认词汇表实例 DEFAULT_VOCAB = FactorVocab() ================================================ FILE: backend/app/alpha_mining/features/__init__.py ================================================ """ 特征构建器模块 - MarketFeatureBuilder: 从行情数据构建特征 - SentimentFeatureBuilder: 从新闻情感分析结果构建特征 """ from .market import MarketFeatureBuilder from .sentiment import 
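A short illustration of the token layout managed by `FactorVocab` in `vocab.py` above: with the default six features, feature tokens occupy IDs 0-5 and operator tokens start at the `feature_offset`. The import path is assumed from the module layout.

```python
from app.alpha_mining.dsl.vocab import FactorVocab

vocab = FactorVocab()
print(vocab.vocab_size, vocab.num_features, vocab.num_ops)

print(vocab.name_to_token("RET"))   # 0  -- features come first
print(vocab.name_to_token("ADD"))   # 6  -- first operator, right after the features
print(vocab.token_to_name(vocab.name_to_token("MA5")))          # "MA5"
print(vocab.get_operator_arity(vocab.name_to_token("GATE")))    # 3
```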
SentimentFeatureBuilder __all__ = ["MarketFeatureBuilder", "SentimentFeatureBuilder"] ================================================ FILE: backend/app/alpha_mining/features/market.py ================================================ """ 行情特征构建器 从原始行情数据(OHLCV)构建因子挖掘所需的标准化特征。 特征列表: - RET: 收益率 - VOL: 波动率(滚动标准差) - VOLUME_CHG: 成交量变化率 - TURNOVER: 换手率 """ import torch from typing import Dict, List, Optional, Union import pandas as pd import numpy as np import logging from ..config import AlphaMiningConfig, DEFAULT_CONFIG logger = logging.getLogger(__name__) class MarketFeatureBuilder: """ 行情特征构建器 从 OHLCV 数据构建标准化的因子特征。 Args: config: 配置实例 vol_window: 波动率计算窗口 normalize: 是否标准化特征 Example: builder = MarketFeatureBuilder() features = builder.build(ohlcv_df) """ # 支持的特征名称 FEATURE_NAMES = ["RET", "VOL", "VOLUME_CHG", "TURNOVER"] def __init__( self, config: Optional[AlphaMiningConfig] = None, vol_window: int = 20, normalize: bool = True ): self.config = config or DEFAULT_CONFIG self.vol_window = vol_window self.normalize = normalize logger.info( f"MarketFeatureBuilder initialized: " f"vol_window={vol_window}, normalize={normalize}" ) def build( self, data: Union[pd.DataFrame, Dict[str, torch.Tensor]], device: Optional[torch.device] = None ) -> torch.Tensor: """ 从行情数据构建特征张量 Args: data: 行情数据,DataFrame 或张量字典 DataFrame 需包含: close, volume, (可选: turnover, shares) Dict 需包含: close, volume device: 目标设备 Returns: 特征张量 [batch, num_features, time_steps] """ device = device or self.config.torch_device if isinstance(data, pd.DataFrame): return self._build_from_dataframe(data, device) elif isinstance(data, dict): return self._build_from_tensors(data, device) else: raise ValueError(f"Unsupported data type: {type(data)}") def _build_from_dataframe( self, df: pd.DataFrame, device: torch.device ) -> torch.Tensor: """ 从 DataFrame 构建特征 支持两种格式: 1. 单股票:index=date, columns=[close, volume, ...] 2. 
多股票:MultiIndex 或 pivot 后的 DataFrame """ # 确保列名小写 df = df.copy() df.columns = [c.lower() for c in df.columns] # 检查必需列 if "close" not in df.columns: raise ValueError("DataFrame must have 'close' column") # 计算各特征 close = torch.tensor(df["close"].values, dtype=torch.float32) # RET: 收益率 ret = self._calc_returns(close) # VOL: 波动率 vol = self._calc_volatility(ret, self.vol_window) # VOLUME_CHG: 成交量变化 if "volume" in df.columns: volume = torch.tensor(df["volume"].values, dtype=torch.float32) volume_chg = self._calc_pct_change(volume) else: volume_chg = torch.zeros_like(ret) # TURNOVER: 换手率 if "turnover" in df.columns: turnover = torch.tensor(df["turnover"].values, dtype=torch.float32) elif "volume" in df.columns and "shares" in df.columns: volume = df["volume"].values shares = df["shares"].values turnover = torch.tensor(volume / (shares + 1e-8), dtype=torch.float32) else: turnover = torch.zeros_like(ret) # Stack features: [num_features, time_steps] features = torch.stack([ret, vol, volume_chg, turnover], dim=0) # 标准化 if self.normalize: features = self._robust_normalize(features) # 添加 batch 维度: [1, num_features, time_steps] features = features.unsqueeze(0).to(device) return features def _build_from_tensors( self, data: Dict[str, torch.Tensor], device: torch.device ) -> torch.Tensor: """ 从张量字典构建特征 Args: data: 包含 close, volume 等张量的字典 每个张量形状为 [batch, time_steps] 或 [time_steps] """ close = data["close"] # 确保是 2D: [batch, time_steps] if close.dim() == 1: close = close.unsqueeze(0) batch_size, time_steps = close.shape # RET ret = self._calc_returns(close) # VOL vol = self._calc_volatility(ret, self.vol_window) # VOLUME_CHG if "volume" in data: volume = data["volume"] if volume.dim() == 1: volume = volume.unsqueeze(0) volume_chg = self._calc_pct_change(volume) else: volume_chg = torch.zeros_like(ret) # TURNOVER if "turnover" in data: turnover = data["turnover"] if turnover.dim() == 1: turnover = turnover.unsqueeze(0) else: turnover = torch.zeros_like(ret) # Stack: [batch, num_features, time_steps] features = torch.stack([ret, vol, volume_chg, turnover], dim=1) # 标准化 if self.normalize: features = self._robust_normalize(features) return features.to(device) def _calc_returns(self, close: torch.Tensor) -> torch.Tensor: """计算收益率""" # close: [batch, time] or [time] if close.dim() == 1: close = close.unsqueeze(0) prev_close = torch.roll(close, 1, dims=-1) prev_close[..., 0] = close[..., 0] returns = (close - prev_close) / (prev_close + 1e-8) returns[..., 0] = 0 # 第一个收益率设为 0 return returns.squeeze(0) if close.size(0) == 1 else returns def _calc_volatility(self, returns: torch.Tensor, window: int) -> torch.Tensor: """计算滚动波动率""" if returns.dim() == 1: returns = returns.unsqueeze(0) batch_size, time_steps = returns.shape # Padding pad = torch.zeros((batch_size, window - 1), device=returns.device) padded = torch.cat([pad, returns], dim=-1) # 滚动标准差 vol = padded.unfold(-1, window, 1).std(dim=-1) return vol.squeeze(0) if batch_size == 1 else vol def _calc_pct_change(self, x: torch.Tensor) -> torch.Tensor: """计算百分比变化""" if x.dim() == 1: x = x.unsqueeze(0) prev = torch.roll(x, 1, dims=-1) prev[..., 0] = x[..., 0] pct = (x - prev) / (prev + 1e-8) pct[..., 0] = 0 return pct.squeeze(0) if x.size(0) == 1 else pct def _robust_normalize(self, features: torch.Tensor) -> torch.Tensor: """ 稳健标准化(使用中位数和 MAD) Args: features: [batch, num_features, time_steps] 或 [num_features, time_steps] """ if features.dim() == 2: features = features.unsqueeze(0) # 计算每个特征的中位数 median = features.median(dim=-1, keepdim=True).values # 计算 MAD mad = 
(features - median).abs().median(dim=-1, keepdim=True).values + 1e-6 # 标准化 normalized = (features - median) / mad # 裁剪极端值 normalized = torch.clamp(normalized, -5.0, 5.0) return normalized def get_feature_names(self) -> List[str]: """获取特征名称列表""" return self.FEATURE_NAMES.copy() def build_batch( self, data_list: List[Union[pd.DataFrame, Dict[str, torch.Tensor]]], device: Optional[torch.device] = None ) -> torch.Tensor: """ 批量构建特征 Args: data_list: 行情数据列表 device: 目标设备 Returns: 特征张量 [batch, num_features, time_steps] """ features_list = [] for data in data_list: features = self.build(data, device) features_list.append(features) return torch.cat(features_list, dim=0) ================================================ FILE: backend/app/alpha_mining/features/sentiment.py ================================================ """ 情感特征构建器 从 FinnewsHunter 的新闻分析结果构建情感特征。 特征列表: - SENTIMENT: 情感分数(-1 到 1) - NEWS_COUNT: 新闻数量(标准化) 与 FinnewsHunter 现有组件集成: - 使用 SentimentAgent 的分析结果 - 从 PostgreSQL/Milvus 获取历史情感数据 """ import torch from typing import Dict, List, Optional, Union, Any import pandas as pd import numpy as np import logging from datetime import datetime, timedelta from ..config import AlphaMiningConfig, DEFAULT_CONFIG logger = logging.getLogger(__name__) class SentimentFeatureBuilder: """ 情感特征构建器 从新闻情感分析结果构建因子特征。 Args: config: 配置实例 sentiment_decay: 情感衰减因子(用于时序平滑) normalize: 是否标准化特征 Example: builder = SentimentFeatureBuilder() features = builder.build(sentiment_data) """ # 支持的特征名称 FEATURE_NAMES = ["SENTIMENT", "NEWS_COUNT"] def __init__( self, config: Optional[AlphaMiningConfig] = None, sentiment_decay: float = 0.9, normalize: bool = True ): self.config = config or DEFAULT_CONFIG self.sentiment_decay = sentiment_decay self.normalize = normalize logger.info( f"SentimentFeatureBuilder initialized: " f"decay={sentiment_decay}, normalize={normalize}" ) def build( self, data: Union[pd.DataFrame, Dict[str, Any], List[Dict]], time_steps: Optional[int] = None, device: Optional[torch.device] = None ) -> torch.Tensor: """ 从情感数据构建特征张量 Args: data: 情感数据,支持多种格式: - DataFrame: columns=[date, sentiment, news_count] - Dict: {"sentiment": [...], "news_count": [...]} - List[Dict]: [{"date": ..., "sentiment": ..., "count": ...}, ...] 
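The median/MAD scheme in `MarketFeatureBuilder._robust_normalize` above keeps a single outlier from dominating the scale the way mean/std normalization would. A toy demonstration on one feature row with an injected outlier:

```python
import torch

x = torch.tensor([[0.01, 0.02, -0.01, 0.015, 0.9]])  # the 0.9 is a deliberate outlier

median = x.median(dim=-1, keepdim=True).values
mad = (x - median).abs().median(dim=-1, keepdim=True).values + 1e-6
z = torch.clamp((x - median) / mad, -5.0, 5.0)

print(z)  # the outlier is capped at 5.0 instead of stretching the whole scale
```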
time_steps: 目标时间步数(用于对齐行情数据) device: 目标设备 Returns: 特征张量 [1, 2, time_steps] (SENTIMENT, NEWS_COUNT) """ device = device or self.config.torch_device if isinstance(data, pd.DataFrame): sentiment, news_count = self._parse_dataframe(data) elif isinstance(data, dict): sentiment, news_count = self._parse_dict(data) elif isinstance(data, list): sentiment, news_count = self._parse_list(data) else: raise ValueError(f"Unsupported data type: {type(data)}") # 转换为张量 sentiment = torch.tensor(sentiment, dtype=torch.float32) news_count = torch.tensor(news_count, dtype=torch.float32) # 对齐时间步 if time_steps is not None: sentiment = self._align_time_steps(sentiment, time_steps) news_count = self._align_time_steps(news_count, time_steps) # 应用情感衰减(指数平滑) sentiment = self._apply_decay(sentiment) # Stack: [2, time_steps] features = torch.stack([sentiment, news_count], dim=0) # 标准化 if self.normalize: features = self._normalize(features) # 添加 batch 维度: [1, 2, time_steps] features = features.unsqueeze(0).to(device) return features def _parse_dataframe(self, df: pd.DataFrame): """从 DataFrame 解析情感数据""" df = df.copy() df.columns = [c.lower() for c in df.columns] # 情感分数 if "sentiment" in df.columns: sentiment = df["sentiment"].fillna(0).values elif "sentiment_score" in df.columns: sentiment = df["sentiment_score"].fillna(0).values else: sentiment = np.zeros(len(df)) logger.warning("No sentiment column found, using zeros") # 新闻数量 if "news_count" in df.columns: news_count = df["news_count"].fillna(0).values elif "count" in df.columns: news_count = df["count"].fillna(0).values else: news_count = np.ones(len(df)) # 默认每天 1 条 return sentiment, news_count def _parse_dict(self, data: Dict[str, Any]): """从字典解析情感数据""" sentiment = data.get("sentiment", data.get("sentiment_score", [])) news_count = data.get("news_count", data.get("count", [])) sentiment = np.array(sentiment) if sentiment else np.array([0]) news_count = np.array(news_count) if news_count else np.array([1]) return sentiment, news_count def _parse_list(self, data: List[Dict]): """从列表解析情感数据""" sentiment = [] news_count = [] for item in data: s = item.get("sentiment", item.get("sentiment_score", 0)) c = item.get("news_count", item.get("count", 1)) sentiment.append(s) news_count.append(c) return np.array(sentiment), np.array(news_count) def _align_time_steps(self, x: torch.Tensor, target_len: int) -> torch.Tensor: """对齐时间步长度""" current_len = len(x) if current_len == target_len: return x elif current_len > target_len: # 截取最近的数据 return x[-target_len:] else: # 前面填充 0 pad = torch.zeros(target_len - current_len) return torch.cat([pad, x]) def _apply_decay(self, sentiment: torch.Tensor) -> torch.Tensor: """ 应用指数衰减平滑 情感影响会随时间衰减,使用指数移动平均来平滑 """ if self.sentiment_decay >= 1.0: return sentiment result = torch.zeros_like(sentiment) result[0] = sentiment[0] for i in range(1, len(sentiment)): result[i] = self.sentiment_decay * result[i-1] + (1 - self.sentiment_decay) * sentiment[i] return result def _normalize(self, features: torch.Tensor) -> torch.Tensor: """标准化特征""" # features: [2, time_steps] # SENTIMENT: 已经在 [-1, 1] 范围内,保持不变 # NEWS_COUNT: 标准化到 0 均值、1 标准差 news_count = features[1] if news_count.std() > 0: features[1] = (news_count - news_count.mean()) / (news_count.std() + 1e-6) # 裁剪极端值 features = torch.clamp(features, -5.0, 5.0) return features def get_feature_names(self) -> List[str]: """获取特征名称列表""" return self.FEATURE_NAMES.copy() def build_from_finnews( self, stock_code: str, start_date: datetime, end_date: datetime, db_session: Any = None, device: Optional[torch.device] = None ) 
-> torch.Tensor: """ 从 FinnewsHunter 数据库构建情感特征 Args: stock_code: 股票代码 start_date: 开始日期 end_date: 结束日期 db_session: 数据库会话(可选,用于真实数据) device: 目标设备 Returns: 特征张量 [1, 2, time_steps] """ device = device or self.config.torch_device # 计算交易日数 time_steps = (end_date - start_date).days if db_session is None: # 无数据库连接时返回模拟数据 logger.warning("No db_session provided, returning mock sentiment data") return self._generate_mock_sentiment(time_steps, device) # TODO: 实现真实数据查询 # 查询逻辑示例: # query = """ # SELECT date, AVG(sentiment_score) as sentiment, COUNT(*) as news_count # FROM news_analysis # WHERE stock_code = :code AND date BETWEEN :start AND :end # GROUP BY date # ORDER BY date # """ # results = db_session.execute(query, {...}) logger.info(f"Building sentiment features for {stock_code}") return self._generate_mock_sentiment(time_steps, device) def _generate_mock_sentiment( self, time_steps: int, device: torch.device ) -> torch.Tensor: """生成模拟情感数据""" # 模拟情感分数(正态分布,均值 0) sentiment = torch.randn(time_steps) * 0.3 sentiment = torch.clamp(sentiment, -1, 1) # 模拟新闻数量(泊松分布) news_count = torch.abs(torch.randn(time_steps)) * 3 + 1 # Stack 并添加 batch 维度 features = torch.stack([sentiment, news_count], dim=0) if self.normalize: features = self._normalize(features) return features.unsqueeze(0).to(device) def combine_with_market( self, market_features: torch.Tensor, sentiment_features: torch.Tensor ) -> torch.Tensor: """ 合并行情特征和情感特征 Args: market_features: [batch, 4, time_steps] (RET, VOL, VOLUME_CHG, TURNOVER) sentiment_features: [batch, 2, time_steps] (SENTIMENT, NEWS_COUNT) Returns: 合并后的特征 [batch, 6, time_steps] """ return torch.cat([market_features, sentiment_features], dim=1) ================================================ FILE: backend/app/alpha_mining/model/__init__.py ================================================ """ 因子生成模型和训练器 - AlphaGenerator: Transformer 策略网络,生成因子表达式 - AlphaTrainer: RL 训练器,使用 REINFORCE 算法优化 """ from .alpha_generator import AlphaGenerator from .trainer import AlphaTrainer __all__ = ["AlphaGenerator", "AlphaTrainer"] ================================================ FILE: backend/app/alpha_mining/model/alpha_generator.py ================================================ """ 因子生成模型 基于 Transformer 的策略网络,用于生成因子表达式 token 序列。 架构: - Token Embedding + Position Embedding - Transformer Encoder(使用 causal mask) - Policy Head(输出 token 概率) - Value Head(估计状态价值,用于 Actor-Critic) References: - AlphaGPT upstream/model_core/alphagpt.py """ import torch import torch.nn as nn from torch.distributions import Categorical from typing import Tuple, List, Optional import logging from ..config import AlphaMiningConfig, DEFAULT_CONFIG from ..dsl.vocab import FactorVocab, DEFAULT_VOCAB logger = logging.getLogger(__name__) class AlphaGenerator(nn.Module): """ 因子生成器(Transformer 策略网络) 使用 Transformer 架构生成因子表达式的 token 序列。 Args: vocab: 词汇表实例 config: 配置实例 Example: generator = AlphaGenerator() tokens = torch.zeros((batch_size, 1), dtype=torch.long) logits, value = generator(tokens) """ def __init__( self, vocab: Optional[FactorVocab] = None, config: Optional[AlphaMiningConfig] = None ): super().__init__() self.vocab = vocab or DEFAULT_VOCAB self.config = config or DEFAULT_CONFIG # 模型参数 self.vocab_size = self.vocab.vocab_size self.d_model = self.config.d_model self.max_seq_len = self.config.max_seq_len # Token Embedding self.token_emb = nn.Embedding(self.vocab_size, self.d_model) # Position Embedding(可学习的位置编码) self.pos_emb = nn.Parameter( torch.zeros(1, self.max_seq_len + 1, self.d_model) ) # Transformer Encoder encoder_layer = 
nn.TransformerEncoderLayer( d_model=self.d_model, nhead=self.config.nhead, dim_feedforward=self.d_model * 2, batch_first=True, dropout=0.1 ) self.transformer = nn.TransformerEncoder( encoder_layer, num_layers=self.config.num_layers ) # Output heads self.ln_f = nn.LayerNorm(self.d_model) self.policy_head = nn.Linear(self.d_model, self.vocab_size) # Actor self.value_head = nn.Linear(self.d_model, 1) # Critic # 初始化权重 self._init_weights() logger.info( f"AlphaGenerator initialized: vocab_size={self.vocab_size}, " f"d_model={self.d_model}, max_seq_len={self.max_seq_len}" ) def _init_weights(self): """初始化模型权重""" # 使用 Xavier 初始化 for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward( self, tokens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ 前向传播 Args: tokens: 输入 token 序列 [batch, seq_len] Returns: logits: 下一个 token 的 logits [batch, vocab_size] value: 状态价值估计 [batch, 1] """ batch_size, seq_len = tokens.size() device = tokens.device # Token + Position Embedding x = self.token_emb(tokens) + self.pos_emb[:, :seq_len, :] # Causal Mask(确保只能看到之前的 token) mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device) # Transformer 编码 x = self.transformer(x, mask=mask, is_causal=True) # Layer Norm x = self.ln_f(x) # 取最后一个位置的表示 last_hidden = x[:, -1, :] # [batch, d_model] # 输出 heads logits = self.policy_head(last_hidden) # [batch, vocab_size] value = self.value_head(last_hidden) # [batch, 1] return logits, value @torch.no_grad() def generate( self, batch_size: int = 1, max_len: Optional[int] = None, temperature: float = 1.0, device: Optional[torch.device] = None ) -> Tuple[List[List[int]], List[torch.Tensor]]: """ 批量生成因子表达式 使用自回归采样生成 token 序列。 Args: batch_size: 生成数量 max_len: 最大长度,默认使用 config.max_seq_len temperature: 采样温度,越高越随机 device: 设备,默认使用 config.device Returns: formulas: 生成的 token 序列列表 log_probs_list: 每个序列的 log_prob 列表(用于策略梯度) """ self.eval() max_len = max_len or self.config.max_seq_len device = device or self.config.torch_device # 初始化:以空 token 开始(使用 0) tokens = torch.zeros((batch_size, 1), dtype=torch.long, device=device) all_log_probs: List[List[torch.Tensor]] = [[] for _ in range(batch_size)] for step in range(max_len): # 前向传播 logits, _ = self.forward(tokens) # 应用温度 if temperature != 1.0: logits = logits / temperature # 采样 dist = Categorical(logits=logits) action = dist.sample() # [batch] # 记录 log_prob log_prob = dist.log_prob(action) # [batch] for i in range(batch_size): all_log_probs[i].append(log_prob[i]) # 拼接到序列 tokens = torch.cat([tokens, action.unsqueeze(1)], dim=1) # 转换为列表格式 formulas = tokens[:, 1:].tolist() # 去掉初始的 0 # 将 log_probs 转换为 tensor 列表 log_probs_tensors = [torch.stack(lps) for lps in all_log_probs] return formulas, log_probs_tensors def generate_with_training( self, batch_size: int = 1, max_len: Optional[int] = None, device: Optional[torch.device] = None ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: """ 生成因子表达式(训练模式,保留梯度) Args: batch_size: 生成数量 max_len: 最大长度 device: 设备 Returns: sequences: 生成的序列 [batch, seq_len] log_probs: 每步的 log_prob 列表 values: 每步的 value 估计列表 """ self.train() max_len = max_len or self.config.max_seq_len device = device or self.config.torch_device # 初始化 tokens = torch.zeros((batch_size, 1), dtype=torch.long, device=device) log_probs_list = [] values_list = [] tokens_list = [] for step in range(max_len): # 前向传播 logits, value 
= self.forward(tokens) # 采样 dist = Categorical(logits=logits) action = dist.sample() # 记录 log_probs_list.append(dist.log_prob(action)) values_list.append(value.squeeze(-1)) tokens_list.append(action) # 拼接 tokens = torch.cat([tokens, action.unsqueeze(1)], dim=1) # 组装结果 sequences = torch.stack(tokens_list, dim=1) # [batch, max_len] return sequences, log_probs_list, values_list def save(self, path: str): """保存模型""" torch.save({ 'model_state_dict': self.state_dict(), 'vocab_size': self.vocab_size, 'd_model': self.d_model, 'max_seq_len': self.max_seq_len, }, path) logger.info(f"Model saved to {path}") @classmethod def load(cls, path: str, vocab: Optional[FactorVocab] = None) -> 'AlphaGenerator': """加载模型""" checkpoint = torch.load(path, map_location='cpu') # 创建模型 config = AlphaMiningConfig( d_model=checkpoint['d_model'], max_seq_len=checkpoint['max_seq_len'] ) model = cls(vocab=vocab, config=config) # 加载权重 model.load_state_dict(checkpoint['model_state_dict']) logger.info(f"Model loaded from {path}") return model ================================================ FILE: backend/app/alpha_mining/model/trainer.py ================================================ """ 因子挖掘 RL 训练器 使用 REINFORCE 算法训练 AlphaGenerator,以回测收益为奖励信号。 训练流程: 1. 生成因子表达式 2. 执行表达式得到因子值 3. 回测评估因子有效性(计算奖励) 4. 策略梯度更新 References: - AlphaGPT upstream/model_core/engine.py """ import torch from typing import Optional, List, Dict, Any, Callable from tqdm import tqdm import logging import json from pathlib import Path from ..config import AlphaMiningConfig, DEFAULT_CONFIG from ..dsl.vocab import FactorVocab, DEFAULT_VOCAB from ..vm.factor_vm import FactorVM from .alpha_generator import AlphaGenerator logger = logging.getLogger(__name__) class AlphaTrainer: """ 因子挖掘 RL 训练器 使用 REINFORCE 算法训练 AlphaGenerator。 Args: generator: 因子生成模型 vocab: 词汇表 config: 配置 evaluator: 因子评估函数,接收 (factor, returns) 返回 score """ def __init__( self, generator: Optional[AlphaGenerator] = None, vocab: Optional[FactorVocab] = None, config: Optional[AlphaMiningConfig] = None, evaluator: Optional[Callable[[torch.Tensor, torch.Tensor], float]] = None ): self.config = config or DEFAULT_CONFIG self.vocab = vocab or DEFAULT_VOCAB self.generator = generator or AlphaGenerator(vocab=self.vocab, config=self.config) self.vm = FactorVM(vocab=self.vocab) # 默认评估器(简单 Sharpe-like) self.evaluator = evaluator or self._default_evaluator # 优化器 self.optimizer = torch.optim.AdamW( self.generator.parameters(), lr=self.config.lr ) # 训练状态 self.best_score = -float('inf') self.best_formula: Optional[List[int]] = None self.best_formula_str: Optional[str] = None self.training_history: List[Dict[str, Any]] = [] self.step_count = 0 # 移动到指定设备 self.device = self.config.torch_device self.generator.to(self.device) logger.info(f"AlphaTrainer initialized on device: {self.device}") def _default_evaluator(self, factor: torch.Tensor, returns: torch.Tensor) -> float: """ 默认因子评估器(简化版 Sharpe-like) Args: factor: 因子值 [batch, time_steps] returns: 收益率 [batch, time_steps] Returns: 评分(越高越好) """ # 因子值作为信号(sigmoid 归一化) signal = torch.sigmoid(factor) # 简单策略:signal > threshold 时持仓 threshold = self.config.signal_threshold position = (signal > threshold).float() # 计算收益 pnl = position * returns # Sharpe-like ratio(简化) mean_pnl = pnl.mean() std_pnl = pnl.std() + 1e-6 score = (mean_pnl / std_pnl).item() return score def train_step( self, features: torch.Tensor, returns: torch.Tensor ) -> Dict[str, Any]: """ 单步训练 Args: features: 特征张量 [batch, num_features, time_steps] returns: 收益率张量 [batch, time_steps] Returns: 训练指标字典 """ 
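Before the body of `train_step` below, here is a self-contained sketch of the REINFORCE update it performs: sample actions, score them, normalize the rewards into advantages, and descend `-(log_prob * advantage)`. The three-action policy and the reward values are made up purely for illustration.

```python
import torch
from torch.distributions import Categorical

logits = torch.zeros(4, 3, requires_grad=True)    # batch of 4 sampled "formulas", 3 actions
dist = Categorical(logits=logits)
actions = dist.sample()
log_probs = dist.log_prob(actions)                # [4]

rewards = torch.tensor([1.0, -5.0, 0.5, 2.0])     # e.g. backtest scores / invalid-formula penalty
adv = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

loss = -(log_probs * adv).mean()
loss.backward()                                   # gradients flow back into the policy logits
print(loss.item(), logits.grad.shape)
```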
self.generator.train() batch_size = self.config.batch_size # 1. 生成因子表达式 sequences, log_probs_list, _ = self.generator.generate_with_training( batch_size=batch_size, device=self.device ) # 2. 执行并评估每个公式 rewards = torch.zeros(batch_size, device=self.device) valid_count = 0 for i in range(batch_size): formula = sequences[i].tolist() # 执行因子表达式 factor = self.vm.execute(formula, features) if factor is None: # 无效公式 rewards[i] = self.config.invalid_formula_reward continue # 检查是否为常量因子 if factor.std() < self.config.constant_threshold: rewards[i] = self.config.constant_factor_reward continue # 评估因子 try: score = self.evaluator(factor, returns) rewards[i] = score valid_count += 1 # 更新最优 if score > self.best_score: self.best_score = score self.best_formula = formula self.best_formula_str = self.vm.decode(formula) logger.info( f"[Step {self.step_count}] New best: " f"score={score:.4f}, formula={self.best_formula_str}" ) except Exception as e: logger.warning(f"Evaluation error: {e}") rewards[i] = self.config.invalid_formula_reward # 3. 计算 advantage(归一化) adv = (rewards - rewards.mean()) / (rewards.std() + 1e-5) # 4. 策略梯度 loss loss = torch.zeros(1, device=self.device) for t, log_prob in enumerate(log_probs_list): loss = loss - (log_prob * adv).mean() # 5. 反向传播 self.optimizer.zero_grad() loss.backward() # 梯度裁剪 torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0) self.optimizer.step() # 6. 记录指标 self.step_count += 1 metrics = { "step": self.step_count, "loss": loss.item(), "avg_reward": rewards.mean().item(), "max_reward": rewards.max().item(), "min_reward": rewards.min().item(), "valid_ratio": valid_count / batch_size, "best_score": self.best_score, "best_formula": self.best_formula_str, } self.training_history.append(metrics) return metrics def train( self, features: torch.Tensor, returns: torch.Tensor, num_steps: Optional[int] = None, progress_bar: bool = True, step_callback: Optional[Callable[[Dict[str, Any]], None]] = None ) -> Dict[str, Any]: """ 完整训练循环 Args: features: 特征张量 [num_samples, num_features, time_steps] returns: 收益率张量 [num_samples, time_steps] num_steps: 训练步数,默认使用 config.num_steps progress_bar: 是否显示进度条 step_callback: 每步回调函数,接收 metrics 字典,用于 SSE 流式推送 Returns: 训练结果 """ num_steps = num_steps or self.config.num_steps logger.info(f"Starting training for {num_steps} steps...") # 确保数据在正确设备上 features = features.to(self.device) returns = returns.to(self.device) iterator = range(num_steps) if progress_bar: iterator = tqdm(iterator, desc="Training") for step in iterator: metrics = self.train_step(features, returns) # 添加进度百分比 metrics["progress"] = (step + 1) / num_steps * 100 metrics["total_steps"] = num_steps if progress_bar: iterator.set_postfix({ "loss": f"{metrics['loss']:.4f}", "avg_rew": f"{metrics['avg_reward']:.4f}", "best": f"{metrics['best_score']:.4f}" }) # 调用回调函数(用于 SSE 流式推送) if step_callback is not None: try: step_callback(metrics) except Exception as e: logger.warning(f"Step callback error: {e}") # 定期保存检查点 if self.step_count % self.config.save_every_n_steps == 0: self._save_checkpoint() # 最终结果 result = { "total_steps": self.step_count, "best_score": self.best_score, "best_formula": self.best_formula, "best_formula_str": self.best_formula_str, "final_metrics": self.training_history[-1] if self.training_history else None, } logger.info(f"Training complete. 
Best score: {self.best_score:.4f}") logger.info(f"Best formula: {self.best_formula_str}") return result def _save_checkpoint(self): """保存训练检查点""" checkpoint_dir = Path(self.config.checkpoint_dir) checkpoint_dir.mkdir(parents=True, exist_ok=True) # 保存模型 model_path = checkpoint_dir / f"model_step_{self.step_count}.pt" self.generator.save(str(model_path)) # 保存最优公式 if self.best_formula: formula_path = checkpoint_dir / "best_formula.json" with open(formula_path, 'w') as f: json.dump({ "formula": self.best_formula, "formula_str": self.best_formula_str, "score": self.best_score, "step": self.step_count }, f, indent=2) def get_best_formula(self) -> Optional[str]: """获取最优因子表达式字符串""" return self.best_formula_str def get_training_history(self) -> List[Dict[str, Any]]: """获取训练历史""" return self.training_history ================================================ FILE: backend/app/alpha_mining/tools/__init__.py ================================================ """ AgenticX 工具封装 将因子挖掘能力封装为 AgenticX Tool,供 QuantitativeAgent 调用。 """ from .alpha_mining_tool import AlphaMiningTool __all__ = ["AlphaMiningTool"] ================================================ FILE: backend/app/alpha_mining/tools/alpha_mining_tool.py ================================================ """ Alpha Mining AgenticX 工具封装 将因子挖掘功能封装为 AgenticX BaseTool,供 Agent 调用。 支持的操作: - mine: 挖掘新因子 - evaluate: 评估现有因子 - list: 列出已发现的因子 """ import torch from typing import Dict, Any, Optional, List from datetime import datetime import logging import json import uuid from agenticx.core.tool_v2 import ( BaseTool, ToolMetadata, ToolParameter, ToolResult, ToolContext, ToolCategory, ToolStatus, ParameterType ) from ..config import AlphaMiningConfig, DEFAULT_CONFIG from ..dsl.vocab import FactorVocab, DEFAULT_VOCAB from ..vm.factor_vm import FactorVM from ..model.alpha_generator import AlphaGenerator from ..model.trainer import AlphaTrainer from ..features.market import MarketFeatureBuilder from ..features.sentiment import SentimentFeatureBuilder from ..backtest.evaluator import FactorEvaluator from ..utils import generate_mock_data logger = logging.getLogger(__name__) class AlphaMiningTool(BaseTool[Dict[str, Any]]): """ Alpha Mining 工具 封装因子挖掘功能,供 QuantitativeAgent 调用。 支持操作: - mine: 使用 RL 挖掘新因子 - evaluate: 评估指定因子表达式 - generate: 生成候选因子 - list: 列出最优因子 Example: tool = AlphaMiningTool() result = tool.execute({ "action": "mine", "num_steps": 100, "use_sentiment": True }, context) """ def __init__( self, config: Optional[AlphaMiningConfig] = None, model_path: Optional[str] = None ): """ 初始化 Alpha Mining 工具 Args: config: 配置实例 model_path: 预训练模型路径 """ self.config = config or DEFAULT_CONFIG metadata = ToolMetadata( name="alpha_mining", version="1.0.0", description="量化因子自动挖掘工具,使用符号回归 + 强化学习发现有效交易因子", category=ToolCategory.ANALYSIS, author="FinnewsHunter Team", tags=["quant", "factor", "alpha", "ml", "reinforcement-learning"], timeout=600, # 10分钟超时 max_retries=1, ) super().__init__(metadata) # 初始化组件 self.vocab = DEFAULT_VOCAB self.vm = FactorVM(vocab=self.vocab) self.evaluator = FactorEvaluator(config=self.config) self.market_builder = MarketFeatureBuilder(config=self.config) self.sentiment_builder = SentimentFeatureBuilder(config=self.config) # 初始化模型 self.generator = AlphaGenerator(vocab=self.vocab, config=self.config) self.trainer: Optional[AlphaTrainer] = None # 加载预训练模型 if model_path: try: self.generator = AlphaGenerator.load(model_path, vocab=self.vocab) logger.info(f"Loaded pretrained model from {model_path}") except Exception as e: logger.warning(f"Failed to load model: 
{e}") # 存储发现的因子 self.discovered_factors: List[Dict[str, Any]] = [] logger.info("AlphaMiningTool initialized") def _setup_parameters(self) -> None: """设置工具参数""" self._parameters = { "action": ToolParameter( name="action", type=ParameterType.STRING, description="操作类型: mine(挖掘), evaluate(评估), generate(生成), list(列表)", required=True, enum=["mine", "evaluate", "generate", "list"] ), "num_steps": ToolParameter( name="num_steps", type=ParameterType.INTEGER, description="训练步数(仅 mine 操作)", required=False, default=100, minimum=1, maximum=10000 ), "formula": ToolParameter( name="formula", type=ParameterType.STRING, description="因子表达式(仅 evaluate 操作)", required=False ), "use_sentiment": ToolParameter( name="use_sentiment", type=ParameterType.BOOLEAN, description="是否使用情感特征", required=False, default=True ), "batch_size": ToolParameter( name="batch_size", type=ParameterType.INTEGER, description="生成因子数量(仅 generate 操作)", required=False, default=10, minimum=1, maximum=100 ), "top_k": ToolParameter( name="top_k", type=ParameterType.INTEGER, description="返回最优因子数量(仅 list 操作)", required=False, default=5, minimum=1, maximum=50 ), "market_data": ToolParameter( name="market_data", type=ParameterType.OBJECT, description="行情数据(可选,不提供则使用模拟数据)", required=False ), "sentiment_data": ToolParameter( name="sentiment_data", type=ParameterType.OBJECT, description="情感数据(可选)", required=False ) } def execute(self, parameters: Dict[str, Any], context: ToolContext) -> ToolResult: """同步执行工具""" start_time = datetime.now() try: validated = self.validate_parameters(parameters) action = validated["action"] if action == "mine": result_data = self._action_mine(validated, context) elif action == "evaluate": result_data = self._action_evaluate(validated, context) elif action == "generate": result_data = self._action_generate(validated, context) elif action == "list": result_data = self._action_list(validated, context) else: raise ValueError(f"Unknown action: {action}") end_time = datetime.now() return ToolResult( status=ToolStatus.SUCCESS, data=result_data, execution_time=(end_time - start_time).total_seconds(), start_time=start_time, end_time=end_time, metadata={"action": action} ) except Exception as e: logger.error(f"AlphaMiningTool error: {e}") end_time = datetime.now() return ToolResult( status=ToolStatus.ERROR, error=str(e), execution_time=(end_time - start_time).total_seconds(), start_time=start_time, end_time=end_time ) async def aexecute(self, parameters: Dict[str, Any], context: ToolContext) -> ToolResult: """异步执行工具""" # 目前使用同步实现 return self.execute(parameters, context) def _action_mine(self, params: Dict[str, Any], context: ToolContext) -> Dict[str, Any]: """执行因子挖掘""" num_steps = params.get("num_steps", 100) use_sentiment = params.get("use_sentiment", True) # 准备特征数据 features, returns = self._prepare_features(params, use_sentiment) # 创建或复用训练器 if self.trainer is None: self.trainer = AlphaTrainer( generator=self.generator, vocab=self.vocab, config=self.config, evaluator=self.evaluator.get_reward ) # 执行训练 logger.info(f"Starting factor mining for {num_steps} steps...") result = self.trainer.train( features=features, returns=returns, num_steps=num_steps, progress_bar=False ) # 保存最优因子 if result["best_formula"]: factor_info = { "id": str(uuid.uuid4()), "formula": result["best_formula"], "formula_str": result["best_formula_str"], "score": result["best_score"], "discovered_at": datetime.now().isoformat(), "training_steps": num_steps, "use_sentiment": use_sentiment } self.discovered_factors.append(factor_info) # 保持只存储最优的 100 个 
self.discovered_factors.sort(key=lambda x: x["score"], reverse=True) self.discovered_factors = self.discovered_factors[:100] return { "success": True, "best_factor": result["best_formula_str"], "best_score": result["best_score"], "total_steps": result["total_steps"], "message": f"因子挖掘完成,最优因子: {result['best_formula_str']} (score={result['best_score']:.4f})" } def _action_evaluate(self, params: Dict[str, Any], context: ToolContext) -> Dict[str, Any]: """评估因子表达式""" formula_str = params.get("formula") if not formula_str: raise ValueError("Parameter 'formula' is required for evaluate action") use_sentiment = params.get("use_sentiment", True) # 解析公式 formula = self._parse_formula(formula_str) if formula is None: return { "success": False, "error": f"Invalid formula: {formula_str}", "message": "无法解析因子表达式" } # 准备数据 features, returns = self._prepare_features(params, use_sentiment) # 执行因子 factor = self.vm.execute(formula, features) if factor is None: return { "success": False, "error": "Formula execution failed", "message": "因子表达式执行失败" } # 评估 metrics = self.evaluator.evaluate(factor, returns) return { "success": True, "formula": formula_str, "metrics": metrics, "message": f"因子评估完成: Sortino={metrics['sortino_ratio']:.4f}, IC={metrics['ic']:.4f}" } def _action_generate(self, params: Dict[str, Any], context: ToolContext) -> Dict[str, Any]: """生成候选因子""" batch_size = params.get("batch_size", 10) use_sentiment = params.get("use_sentiment", True) # 生成因子 formulas, _ = self.generator.generate(batch_size=batch_size) # 准备数据用于评估 features, returns = self._prepare_features(params, use_sentiment) # 评估每个因子 results = [] for formula in formulas: factor = self.vm.execute(formula, features) if factor is not None and factor.std() > 1e-6: try: metrics = self.evaluator.evaluate(factor, returns) results.append({ "formula": formula, "formula_str": self.vm.decode(formula), "sortino": metrics["sortino_ratio"], "ic": metrics["ic"] }) except Exception: continue # 按 Sortino 排序 results.sort(key=lambda x: x["sortino"], reverse=True) return { "success": True, "generated": len(formulas), "valid": len(results), "factors": results[:10], # 返回 top 10 "message": f"生成 {len(formulas)} 个因子,其中 {len(results)} 个有效" } def _action_list(self, params: Dict[str, Any], context: ToolContext) -> Dict[str, Any]: """列出已发现的因子""" top_k = params.get("top_k", 5) factors = self.discovered_factors[:top_k] return { "success": True, "total_discovered": len(self.discovered_factors), "factors": factors, "message": f"共发现 {len(self.discovered_factors)} 个因子,返回 top {len(factors)}" } def _prepare_features( self, params: Dict[str, Any], use_sentiment: bool ) -> tuple: """准备特征数据""" market_data = params.get("market_data") sentiment_data = params.get("sentiment_data") if market_data is not None: # 使用提供的行情数据 market_features = self.market_builder.build(market_data) time_steps = market_features.size(-1) if use_sentiment and sentiment_data is not None: sentiment_features = self.sentiment_builder.build( sentiment_data, time_steps=time_steps ) features = self.sentiment_builder.combine_with_market( market_features, sentiment_features ) else: features = market_features # 假设收益率在行情数据中 returns = market_features[:, 0, :] # RET 特征 else: # 使用模拟数据 num_features = 6 if use_sentiment else 4 features, returns = generate_mock_data( num_samples=50, num_features=num_features, time_steps=252, seed=42 ) return features, returns def _parse_formula(self, formula_str: str) -> Optional[List[int]]: """解析因子表达式字符串""" # 简单解析:尝试匹配已知的 token tokens = [] # 移除括号和空格,按操作符分割 clean = formula_str.replace("(", " 
").replace(")", " ").replace(",", " ") parts = clean.split() for part in parts: part = part.strip() if not part: continue # 尝试作为特征名 try: token = self.vocab.name_to_token(part) tokens.append(token) except (ValueError, KeyError): # 尝试作为数字(常量) try: float(part) # 忽略常量 continue except ValueError: logger.warning(f"Unknown token: {part}") return None return tokens if tokens else None ================================================ FILE: backend/app/alpha_mining/utils.py ================================================ """ Alpha Mining 工具函数 提供模拟数据生成、数据预处理等工具函数。 """ import torch import numpy as np from typing import Tuple, Optional import logging from .config import AlphaMiningConfig, DEFAULT_CONFIG logger = logging.getLogger(__name__) def generate_mock_data( num_samples: int = 100, num_features: int = 6, time_steps: int = 252, seed: Optional[int] = 42, device: Optional[torch.device] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ 生成模拟行情数据用于测试 Args: num_samples: 样本数(股票数) num_features: 特征数 time_steps: 时间步数(交易日数) seed: 随机种子 device: 设备 Returns: features: [num_samples, num_features, time_steps] returns: [num_samples, time_steps] """ if seed is not None: torch.manual_seed(seed) np.random.seed(seed) device = device or DEFAULT_CONFIG.torch_device # 生成模拟收益率(正态分布) returns = torch.randn(num_samples, time_steps, device=device) * 0.02 # 生成模拟价格(累积收益) prices = torch.exp(returns.cumsum(dim=1)) # 生成模拟特征 features_list = [] # Feature 0: RET - 收益率 ret = returns.clone() features_list.append(ret) # Feature 1: VOL - 波动率(滚动 20 日标准差) vol = _rolling_std(returns, window=20) features_list.append(vol) # Feature 2: VOLUME_CHG - 成交量变化(模拟) volume = torch.abs(torch.randn(num_samples, time_steps, device=device)) volume_chg = _pct_change(volume) features_list.append(volume_chg) # Feature 3: TURNOVER - 换手率(模拟) turnover = torch.abs(torch.randn(num_samples, time_steps, device=device)) * 0.05 features_list.append(turnover) # Feature 4: SENTIMENT - 情感分数(模拟) sentiment = torch.randn(num_samples, time_steps, device=device) * 0.5 features_list.append(sentiment) # Feature 5: NEWS_COUNT - 新闻数量(模拟) news_count = torch.abs(torch.randn(num_samples, time_steps, device=device)) * 5 features_list.append(news_count) # 如果需要更多特征,填充随机噪声 while len(features_list) < num_features: noise = torch.randn(num_samples, time_steps, device=device) features_list.append(noise) # 截取到指定特征数 features_list = features_list[:num_features] # Stack features: [num_samples, num_features, time_steps] features = torch.stack(features_list, dim=1) # 标准化特征 features = _robust_normalize(features) logger.debug( f"Generated mock data: features {features.shape}, returns {returns.shape}" ) return features, returns def _rolling_std(x: torch.Tensor, window: int = 20) -> torch.Tensor: """ 计算滚动标准差 Args: x: [batch, time_steps] window: 窗口大小 Returns: 滚动标准差 [batch, time_steps] """ batch_size, time_steps = x.shape device = x.device # Padding pad = torch.zeros((batch_size, window - 1), device=device) x_padded = torch.cat([pad, x], dim=1) # 使用 unfold 计算滚动窗口 result = x_padded.unfold(1, window, 1).std(dim=-1) return result def _pct_change(x: torch.Tensor) -> torch.Tensor: """ 计算百分比变化 Args: x: [batch, time_steps] Returns: 百分比变化 [batch, time_steps] """ prev = torch.roll(x, 1, dims=1) prev[:, 0] = x[:, 0] # 第一个值不变 pct = (x - prev) / (prev + 1e-8) return pct def _robust_normalize(x: torch.Tensor) -> torch.Tensor: """ 稳健标准化(使用中位数和 MAD) Args: x: [batch, num_features, time_steps] Returns: 标准化后的张量 """ # 计算每个特征的中位数 median = x.median(dim=2, keepdim=True).values # 计算 MAD (Median Absolute Deviation) mad = (x 
- median).abs().median(dim=2, keepdim=True).values + 1e-6 # 标准化 normalized = (x - median) / mad # 裁剪极端值 normalized = torch.clamp(normalized, -5.0, 5.0) return normalized def set_random_seed(seed: int): """设置随机种子以确保可复现性""" torch.manual_seed(seed) np.random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def get_device() -> torch.device: """获取最佳可用设备""" if torch.cuda.is_available(): return torch.device("cuda") elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return torch.device("mps") else: return torch.device("cpu") ================================================ FILE: backend/app/alpha_mining/vm/__init__.py ================================================ """ 因子执行器模块 提供 FactorVM 栈式虚拟机,用于执行因子表达式。 """ from .factor_vm import FactorVM __all__ = ["FactorVM"] ================================================ FILE: backend/app/alpha_mining/vm/factor_vm.py ================================================ """ 因子表达式执行器(栈式虚拟机) 使用栈式执行方式解析和执行因子表达式 token 序列。 执行流程: 1. 遍历 token 序列 2. 如果是特征 token:将对应特征数据入栈 3. 如果是操作符 token:弹出所需参数,执行操作,结果入栈 4. 最终栈中应只剩一个结果 References: - AlphaGPT upstream/model_core/vm.py """ import torch from typing import List, Optional, Union import logging from ..dsl.vocab import FactorVocab, DEFAULT_VOCAB logger = logging.getLogger(__name__) class FactorVM: """ 因子表达式栈式虚拟机 执行因子表达式 token 序列,返回计算结果。 Example: vm = FactorVM() # features: [batch, num_features, time_steps] # formula: [0, 1, 6] 表示 ADD(RET, VOL) result = vm.execute([0, 1, 6], features) """ def __init__(self, vocab: Optional[FactorVocab] = None): """ 初始化虚拟机 Args: vocab: 词汇表实例,默认使用 DEFAULT_VOCAB """ self.vocab = vocab or DEFAULT_VOCAB def execute( self, formula: List[int], features: torch.Tensor ) -> Optional[torch.Tensor]: """ 执行因子表达式 Args: formula: token 序列,如 [0, 1, 6] 表示 ADD(RET, VOL) features: 特征张量,形状 [batch, num_features, time_steps] Returns: 因子值张量 [batch, time_steps],如果表达式无效则返回 None Note: - 如果堆栈溢出/不足,返回 None - 如果结果包含 NaN/Inf,会自动替换为 0 - 如果最终堆栈不是恰好一个元素,返回 None """ stack: List[torch.Tensor] = [] try: for token in formula: token = int(token) if self.vocab.is_feature(token): # 特征 token:从特征张量中取出对应特征 if token >= features.shape[1]: logger.debug(f"Feature index {token} out of range") return None stack.append(features[:, token, :]) elif self.vocab.is_operator(token): # 操作符 token:执行操作 arity = self.vocab.get_operator_arity(token) # 检查堆栈是否有足够参数 if len(stack) < arity: logger.debug(f"Stack underflow: need {arity}, have {len(stack)}") return None # 弹出参数(注意顺序:先弹出的是后入的) args = [] for _ in range(arity): args.append(stack.pop()) args.reverse() # 恢复正确顺序 # 执行操作 func = self.vocab.get_operator_func(token) result = func(*args) # 处理 NaN 和 Inf if torch.isnan(result).any() or torch.isinf(result).any(): result = torch.nan_to_num( result, nan=0.0, posinf=1.0, neginf=-1.0 ) stack.append(result) else: # 未知 token logger.debug(f"Unknown token: {token}") return None # 检查最终堆栈状态 if len(stack) == 1: return stack[0] else: logger.debug(f"Invalid stack state: {len(stack)} elements remaining") return None except Exception as e: logger.debug(f"Execution error: {e}") return None def decode(self, formula: List[int]) -> str: """ 将 token 序列解码为人类可读的表达式字符串 使用逆波兰表达式解析,转换为前缀表示法(函数调用形式) Args: formula: token 序列 Returns: 人类可读的表达式,如 "ADD(RET, VOL)" Example: vm.decode([0, 1, 6]) # -> "ADD(RET, VOL)" vm.decode([0, 4]) # -> "NEG(RET)" """ stack: List[str] = [] try: for token in formula: token = int(token) if self.vocab.is_feature(token): # 特征:直接入栈名称 name = self.vocab.token_to_name(token) stack.append(name) elif 
self.vocab.is_operator(token): # 操作符:弹出参数,构建表达式 name = self.vocab.token_to_name(token) arity = self.vocab.get_operator_arity(token) if len(stack) < arity: return f"" args = [] for _ in range(arity): args.append(stack.pop()) args.reverse() # 构建函数调用形式 expr = f"{name}({', '.join(args)})" stack.append(expr) else: return f"" if len(stack) == 1: return stack[0] elif len(stack) == 0: return "" else: # 多个元素:用逗号连接 return f"" except Exception as e: return f"" def validate(self, formula: List[int]) -> bool: """ 验证因子表达式是否语法正确 使用模拟执行(不实际计算)来验证。 Args: formula: token 序列 Returns: True 如果表达式语法正确 """ stack_depth = 0 try: for token in formula: token = int(token) if self.vocab.is_feature(token): stack_depth += 1 elif self.vocab.is_operator(token): arity = self.vocab.get_operator_arity(token) if stack_depth < arity: return False stack_depth -= arity stack_depth += 1 # 操作结果 else: return False return stack_depth == 1 except Exception: return False def get_required_features(self, formula: List[int]) -> List[int]: """ 获取表达式中使用的特征列表 Args: formula: token 序列 Returns: 使用的特征 token 列表(去重) """ features = [] for token in formula: token = int(token) if self.vocab.is_feature(token) and token not in features: features.append(token) return features ================================================ FILE: backend/app/api/__init__.py ================================================ """ API模块 """ ================================================ FILE: backend/app/api/v1/__init__.py ================================================ """ API v1 模块 """ from fastapi import APIRouter from . import analysis, tasks, llm_config, stocks, agents, debug, knowledge_graph from . import news # 原有的新闻 API(数据库操作) from . import news_v2 # 新版 API(Provider-Fetcher 实时获取) from . import alpha_mining # 因子挖掘 API # 创建主路由器 api_router = APIRouter() # 注册子路由 api_router.include_router(news.router, prefix="/news", tags=["news"]) # 原有端点 api_router.include_router(news_v2.router, prefix="/news/v2", tags=["news-v2"]) # 新版端点 api_router.include_router(analysis.router, prefix="/analysis", tags=["analysis"]) api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) api_router.include_router(llm_config.router, prefix="/llm", tags=["llm"]) api_router.include_router(stocks.router, prefix="/stocks", tags=["stocks"]) # Phase 2: 个股分析 api_router.include_router(agents.router, prefix="/agents", tags=["agents"]) # Phase 2: 智能体监控 api_router.include_router(debug.router, prefix="/debug", tags=["debug"]) # 调试工具 api_router.include_router(knowledge_graph.router, prefix="/knowledge-graph", tags=["knowledge-graph"]) # 知识图谱 api_router.include_router(alpha_mining.router) # 因子挖掘 __all__ = ["api_router"] ================================================ FILE: backend/app/api/v1/agents.py ================================================ """ 智能体 API 路由 - Phase 2 提供辩论功能、执行日志、性能监控等接口 """ import logging import json import asyncio from datetime import datetime, timedelta from typing import List, Optional, Dict, Any, AsyncGenerator from fastapi import APIRouter, Depends, HTTPException, Query, Body from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func, desc, or_ from ...core.database import get_db from ...models.news import News from ...models.analysis import Analysis from ...agents import ( create_debate_workflow, create_orchestrator, create_data_collector ) from ...services.llm_service import get_llm_provider from ...services.stock_data_service import stock_data_service 
logger = logging.getLogger(__name__) router = APIRouter() # ============ 多语言提示词辅助函数 ============ def get_prompts(language: str = "zh") -> Dict[str, str]: """获取多语言提示词""" if language == "en": return { "quick_analyst_system": "You are a professional stock analyst, skilled in quick analysis and decision-making.", "quick_analysis_prompt": """Please provide a quick investment analysis for {stock_name}({stock_code}). Background: {context} Related News: {news} Please quickly provide: 1. Core Viewpoint (one sentence) 2. Bullish Factors (3 points) 3. Bearish Factors (3 points) 4. Investment Recommendation (Buy/Hold/Sell) 5. Risk Warning""", "data_collector_content": "📊 Collected relevant data for {stock_name}: {count} news items, financial data ready.\n\nDebate will begin in {rounds} rounds.", "bull_system": "You are a bullish researcher, skilled at analyzing stocks from a positive perspective. When answering user questions, maintain an optimistic but rational attitude.", "bear_system": "You are a bearish researcher, skilled at identifying risks. When answering user questions, remain cautious and focus on potential risks.", "manager_system": "You are an experienced investment manager, skilled at comprehensive analysis and providing investment advice. Answer user questions objectively and professionally.", "phase_start": "Starting {mode} mode analysis", "phase_analyzing": "Quick analyst is analyzing...", "phase_data_collection": "Data Collector is gathering materials...", "role_quick_analyst": "Quick Analyst", "role_data_collector": "Data Collector", "round_debate": "Round {round}/{max_rounds} debate", "role_bull": "Bull Researcher", "role_bear": "Bear Researcher", "bull_first_round": """You are a bullish researcher participating in a bull vs bear debate about {stock_name}({stock_code}). Background: {context} News: {news} This is Round 1. Please make an opening statement (about 150 words): 1. State your core bullish view 2. Provide 2-3 key arguments""", "bull_subsequent_rounds": """You are a bullish researcher debating with a bearish researcher about {stock_name}. This is Round {round}. The bearish researcher just said: "{bear_last_statement}" Please refute the opponent's arguments and add new points (about 120 words): 1. Point out flaws in the opponent's arguments 2. Add new bullish reasons""", "bear_first_round": """You are a bearish researcher participating in a bull vs bear debate about {stock_name}({stock_code}). Background: {context} News: {news} This is Round 1. Please make an opening statement (about 150 words): 1. State your core bearish view 2. Provide 2-3 key risk points""", "bear_subsequent_rounds": """You are a bearish researcher debating with a bullish researcher about {stock_name}. This is Round {round}. The bullish researcher just said: "{bull_last_statement}" Please refute the opponent's arguments and add new points (about 120 words): 1. Point out flaws in the opponent's arguments 2. Add new risk points""", "manager_decision": """You are an investment manager synthesizing the debate between bullish and bearish researchers to make a final investment decision. Stock: {stock_name}({stock_code}) Bullish Researcher's View: {bull_analysis} Bearish Researcher's View: {bear_analysis} Please provide the final decision (about 200 words): 1. Comprehensive evaluation of both views 2. Investment recommendation (Strongly Recommend/Recommend/Neutral/Avoid/Caution) 3. 
Reasoning and risk warnings""", } else: # zh (default) return { "quick_analyst_system": "你是一位专业的股票分析师,擅长快速分析和决策。", "quick_analysis_prompt": """请对 {stock_name}({stock_code}) 进行快速投资分析。 背景资料: {context} 相关新闻: {news} 请快速给出: 1. 核心观点(一句话) 2. 看多因素(3点) 3. 看空因素(3点) 4. 投资建议(买入/持有/卖出) 5. 风险提示""", "data_collector_content": "📊 已搜集 {stock_name} 的相关数据:{count} 条新闻,财务数据已就绪。\n\n辩论即将开始,共 {rounds} 轮。", "bull_system": "你是一位看多研究员,擅长从积极角度分析股票。回答用户问题时保持乐观但理性的态度。", "bear_system": "你是一位看空研究员,擅长发现风险。回答用户问题时保持谨慎,重点指出潜在风险。", "manager_system": "你是一位经验丰富的投资经理,擅长综合分析和给出投资建议。回答用户问题时客观、专业。", "phase_start": "开始{mode}模式分析", "phase_analyzing": "快速分析师正在分析...", "phase_data_collection": "数据专员正在搜集资料...", "role_quick_analyst": "快速分析师", "role_data_collector": "数据专员", "round_debate": "第 {round}/{max_rounds} 轮辩论", "role_bull": "看多研究员", "role_bear": "看空研究员", "bull_first_round": """你是看多研究员,正在参与关于 {stock_name}({stock_code}) 的多空辩论。 背景资料: {context} 新闻: {news} 这是第1轮辩论,请做开场陈述(约150字): 1. 表明你的核心看多观点 2. 给出2-3个关键论据""", "bull_subsequent_rounds": """你是看多研究员,正在与看空研究员辩论 {stock_name}。 这是第{round}轮辩论。 对方(看空研究员)刚才说: "{bear_last_statement}" 请反驳对方观点并补充新论据(约120字): 1. 指出对方论据的漏洞 2. 补充新的看多理由""", "bear_first_round": """你是看空研究员,正在参与关于 {stock_name}({stock_code}) 的多空辩论。 背景资料: {context} 新闻: {news} 这是第1轮辩论,请做开场陈述(约150字): 1. 表明你的核心看空观点 2. 给出2-3个关键风险点""", "bear_subsequent_rounds": """你是看空研究员,正在与看多研究员辩论 {stock_name}。 这是第{round}轮辩论。 对方(看多研究员)刚才说: "{bull_last_statement}" 请反驳对方观点并补充新论据(约120字): 1. 指出对方论据的漏洞 2. 补充新的风险点""", "manager_decision": """你是投资经理,正在综合看多和看空研究员的辩论,做出最终投资决策。 股票: {stock_name}({stock_code}) 看多研究员观点: {bull_analysis} 看空研究员观点: {bear_analysis} 请给出最终决策(约200字): 1. 综合评估双方观点 2. 给出投资建议(强烈推荐/推荐/中性/回避/谨慎) 3. 说明理由和风险提示""", } # ============ 模拟数据存储(生产环境应使用数据库) ============ # 存储执行日志 execution_logs: List[Dict[str, Any]] = [] # 存储辩论结果 debate_results: Dict[str, Dict[str, Any]] = {} # ============ Pydantic 模型 ============ class DebateRequest(BaseModel): """辩论请求""" stock_code: str = Field(..., description="股票代码") stock_name: Optional[str] = Field(None, description="股票名称") context: Optional[str] = Field(None, description="额外背景信息") provider: Optional[str] = Field(None, description="LLM提供商") model: Optional[str] = Field(None, description="模型名称") mode: Optional[str] = Field("parallel", description="辩论模式: parallel, realtime_debate, quick_analysis") language: Optional[str] = Field("zh", description="语言设置: zh=中文, en=英文") class DebateResponse(BaseModel): """辩论响应""" success: bool debate_id: Optional[str] = None stock_code: str stock_name: Optional[str] = None mode: Optional[str] = None # 辩论模式 bull_analysis: Optional[Dict[str, Any]] = None bear_analysis: Optional[Dict[str, Any]] = None final_decision: Optional[Dict[str, Any]] = None quick_analysis: Optional[Dict[str, Any]] = None # 快速分析结果 debate_history: Optional[List[Dict[str, Any]]] = None # 实时辩论历史 trajectory: Optional[List[Dict[str, Any]]] = None execution_time: Optional[float] = None error: Optional[str] = None class AgentLogEntry(BaseModel): """智能体日志条目""" id: str timestamp: str agent_name: str agent_role: Optional[str] = None action: str status: str # "started", "completed", "failed" details: Optional[Dict[str, Any]] = None execution_time: Optional[float] = None class AgentMetrics(BaseModel): """智能体性能指标""" total_executions: int successful_executions: int failed_executions: int avg_execution_time: float agent_stats: Dict[str, Dict[str, Any]] recent_activity: List[Dict[str, Any]] class TrajectoryStep(BaseModel): """执行轨迹步骤""" step_id: str step_name: str timestamp: str agent_name: Optional[str] = None input_data: Optional[Dict[str, 
Any]] = None output_data: Optional[Dict[str, Any]] = None duration: Optional[float] = None status: str class SearchPlanRequest(BaseModel): """生成搜索计划请求""" query: str stock_code: str stock_name: Optional[str] = None class SearchExecuteRequest(BaseModel): """执行搜索计划请求""" plan: Dict[str, Any] # 完整的 SearchPlan 对象 # ============ API 端点 ============ @router.post("/debate", response_model=DebateResponse) async def run_stock_debate( request: DebateRequest, db: AsyncSession = Depends(get_db) ): """ 触发股票辩论分析(Bull vs Bear) - **stock_code**: 股票代码 - **stock_name**: 股票名称(可选) - **context**: 额外背景信息(可选) - **provider**: LLM提供商(可选) - **model**: 模型名称(可选) """ logger.info(f"🎯 收到辩论请求: stock_code={request.stock_code}, stock_name={request.stock_name}") start_time = datetime.utcnow() debate_id = f"debate_{start_time.strftime('%Y%m%d%H%M%S')}_{request.stock_code}" try: # 记录开始 log_entry = { "id": debate_id, "timestamp": start_time.isoformat(), "agent_name": "DebateWorkflow", "action": "debate_start", "status": "started", "details": { "stock_code": request.stock_code, "stock_name": request.stock_name } } execution_logs.append(log_entry) # 标准化股票代码 code = request.stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): short_code = code[2:] else: short_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" logger.info(f"🔍 查询股票 {code} 的关联新闻...") # 获取关联新闻 - 使用 PostgreSQL 原生 ARRAY 查询语法 from sqlalchemy import text stock_codes_filter = text( "stock_codes @> ARRAY[:code1]::varchar[] OR stock_codes @> ARRAY[:code2]::varchar[]" ).bindparams(code1=short_code, code2=code) news_query = select(News).where(stock_codes_filter).order_by(desc(News.publish_time)).limit(10) result = await db.execute(news_query) news_list = result.scalars().all() logger.info(f"📰 找到 {len(news_list)} 条关联新闻") news_data = [ { "id": n.id, "title": n.title, "content": n.content[:500], "sentiment_score": n.sentiment_score, "publish_time": n.publish_time.isoformat() if n.publish_time else None } for n in news_list ] # 如果没有关联新闻,给出警告 if not news_data: logger.warning(f"⚠️ 股票 {code} 没有关联新闻,辩论将基于空数据进行") # 获取财务数据和资金流向(用于增强辩论上下文) logger.info(f"📊 获取 {code} 的财务数据和资金流向...") try: debate_context = await stock_data_service.get_debate_context(code) akshare_context = debate_context.get("summary", "") logger.info(f"📊 获取到额外数据: {akshare_context[:100]}...") except Exception as e: logger.warning(f"⚠️ 获取财务数据失败: {e}") akshare_context = "" # 合并用户提供的上下文和 akshare 数据 full_context = "" if request.context: full_context += f"【用户补充信息】\n{request.context}\n\n" if akshare_context: full_context += f"【实时数据】\n{akshare_context}" # 创建 LLM provider(如果指定了自定义配置) llm_provider = None if request.provider or request.model: logger.info(f"🤖 使用自定义模型: provider={request.provider}, model={request.model}") llm_provider = get_llm_provider( provider=request.provider, model=request.model ) else: logger.info("🤖 使用默认 LLM 配置") # 选择辩论模式 mode = request.mode or "parallel" logger.info(f"⚔️ 开始辩论工作流,模式: {mode}") if mode == "parallel": # 使用原有的并行工作流 workflow = create_debate_workflow(llm_provider) debate_result = await workflow.run_debate( stock_code=code, stock_name=request.stock_name or code, news_list=news_data, context=full_context ) else: # 使用新的编排器(支持 realtime_debate 和 quick_analysis) orchestrator = create_orchestrator(mode=mode, llm_provider=llm_provider) debate_result = await orchestrator.run( stock_code=code, stock_name=request.stock_name or code, context=full_context, news_list=news_data ) end_time = datetime.utcnow() execution_time = (end_time - start_time).total_seconds() # 存储结果 
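        # (Added descriptive note) debate_results is the module-level in-memory dict
        # declared above; keying by debate_id makes this run retrievable later through
        # GET /debate/{debate_id} and GET /trajectory/{debate_id}. It is not persisted
        # across process restarts.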
debate_results[debate_id] = debate_result # 记录完成 log_entry = { "id": f"{debate_id}_complete", "timestamp": end_time.isoformat(), "agent_name": "DebateWorkflow", "action": "debate_complete", "status": "completed" if debate_result.get("success") else "failed", "details": { "stock_code": request.stock_code, "rating": debate_result.get("final_decision", {}).get("rating", "unknown") }, "execution_time": execution_time } execution_logs.append(log_entry) if debate_result.get("success"): return DebateResponse( success=True, debate_id=debate_id, stock_code=code, stock_name=request.stock_name, mode=mode, bull_analysis=debate_result.get("bull_analysis"), bear_analysis=debate_result.get("bear_analysis"), final_decision=debate_result.get("final_decision"), quick_analysis=debate_result.get("quick_analysis"), debate_history=debate_result.get("debate_history"), trajectory=debate_result.get("trajectory"), execution_time=execution_time ) else: return DebateResponse( success=False, debate_id=debate_id, stock_code=code, mode=mode, error=debate_result.get("error", "Unknown error") ) except Exception as e: logger.error(f"Debate failed: {e}", exc_info=True) # 记录失败 log_entry = { "id": f"{debate_id}_error", "timestamp": datetime.utcnow().isoformat(), "agent_name": "DebateWorkflow", "action": "debate_error", "status": "failed", "details": {"error": str(e)} } execution_logs.append(log_entry) return DebateResponse( success=False, debate_id=debate_id, stock_code=request.stock_code, error=str(e) ) # ============ SSE 流式辩论 ============ async def generate_debate_stream( stock_code: str, stock_name: str, mode: str, context: str, news_data: List[Dict], llm_provider, language: str = "zh" ) -> AsyncGenerator[str, None]: """ 生成辩论的 SSE 流 事件类型: - phase: 阶段变化 - agent: 智能体发言 - progress: 进度更新 - result: 最终结果 - error: 错误信息 """ debate_id = f"debate_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}" prompts = get_prompts(language) def sse_event(event_type: str, data: Dict) -> str: """格式化 SSE 事件""" return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" try: # 发送开始事件 yield sse_event("phase", { "phase": "start", "message": prompts["phase_start"].format(mode=mode), "debate_id": debate_id }) if mode == "quick_analysis": # 快速分析模式 - 使用流式输出 yield sse_event("phase", {"phase": "analyzing", "message": prompts["phase_analyzing"]}) news_titles = json.dumps([n.get('title', '') for n in news_data[:5]], ensure_ascii=False) prompt = prompts["quick_analysis_prompt"].format( stock_name=stock_name, stock_code=stock_code, context=context[:2000], news=news_titles ) messages = [ {"role": "system", "content": prompts["quick_analyst_system"]}, {"role": "user", "content": prompt} ] full_response = "" for chunk in llm_provider.stream(messages): full_response += chunk yield sse_event("agent", { "agent": "QuickAnalyst", "role": prompts["role_quick_analyst"], "content": chunk, "is_chunk": True }) await asyncio.sleep(0) # 让出控制权 # 发送完成事件 yield sse_event("result", { "success": True, "mode": mode, "quick_analysis": { "analysis": full_response, "success": True }, "execution_time": 0 }) elif mode == "realtime_debate": # 实时辩论模式 - 多轮交锋 max_rounds = 3 # 最大辩论轮数 yield sse_event("phase", {"phase": "data_collection", "message": prompts["phase_data_collection"]}) await asyncio.sleep(0.3) # 数据搜集 yield sse_event("agent", { "agent": "DataCollector", "role": prompts["role_data_collector"], "content": prompts["data_collector_content"].format( stock_name=stock_name, count=len(news_data), rounds=max_rounds ), "is_chunk": False }) # 辩论历史(用于上下文) debate_history = [] 
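            # (Added descriptive note) Every speaker turn in the loop below follows the
            # same SSE pattern: one "agent" event with is_start=True, a stream of
            # is_chunk=True events carrying LLM output deltas, then a closing event with
            # is_end=True; separate "phase" events announce the current round number.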
bull_full = "" bear_full = "" # 多轮辩论 for round_num in range(1, max_rounds + 1): yield sse_event("phase", { "phase": "debate", "message": prompts["round_debate"].format(round=round_num, max_rounds=max_rounds), "round": round_num, "max_rounds": max_rounds }) # === Bull 发言 === yield sse_event("agent", { "agent": "BullResearcher", "role": prompts["role_bull"], "content": "", "is_start": True, "round": round_num }) if round_num == 1: # 第一轮:开场陈述 news_titles = json.dumps([n.get('title', '') for n in news_data[:3]], ensure_ascii=False) bull_prompt = prompts["bull_first_round"].format( stock_name=stock_name, stock_code=stock_code, context=context[:800], news=news_titles ) else: # 后续轮次:反驳对方 last_bear = debate_history[-1]["content"] if debate_history else "" bull_prompt = prompts["bull_subsequent_rounds"].format( stock_name=stock_name, round=round_num, bear_last_statement=last_bear[:300] ) bull_system_msg = prompts["bull_system"] if language == "en" else "你是一位辩论中的看多研究员。言简意赅,有理有据,语气自信但不傲慢。" bull_messages = [ {"role": "system", "content": bull_system_msg}, {"role": "user", "content": bull_prompt} ] bull_response = "" for chunk in llm_provider.stream(bull_messages): bull_response += chunk yield sse_event("agent", { "agent": "BullResearcher", "role": "看多研究员", "content": chunk, "is_chunk": True, "round": round_num }) await asyncio.sleep(0) round_marker = f"\n\n**【Round {round_num}】**\n" if language == "en" else f"\n\n**【第{round_num}轮】**\n" bull_full += round_marker + bull_response debate_history.append({"agent": "Bull", "round": round_num, "content": bull_response}) yield sse_event("agent", { "agent": "BullResearcher", "role": prompts["role_bull"], "content": "", "is_end": True, "round": round_num }) # === Bear 发言(反驳) === yield sse_event("agent", { "agent": "BearResearcher", "role": prompts["role_bear"], "content": "", "is_start": True, "round": round_num }) if round_num == 1: news_titles = json.dumps([n.get('title', '') for n in news_data[:3]], ensure_ascii=False) bear_prompt = prompts["bear_first_round"].format( stock_name=stock_name, stock_code=stock_code, context=context[:800], news=news_titles ) else: bear_prompt = prompts["bear_subsequent_rounds"].format( stock_name=stock_name, round=round_num, bull_last_statement=bull_response[:300] ) bear_system_msg = prompts["bear_system"] if language == "en" else "你是一位辩论中的看空研究员。言简意赅,善于发现风险,语气谨慎但有说服力。" bear_messages = [ {"role": "system", "content": bear_system_msg}, {"role": "user", "content": bear_prompt} ] bear_response = "" for chunk in llm_provider.stream(bear_messages): bear_response += chunk yield sse_event("agent", { "agent": "BearResearcher", "role": prompts["role_bear"], "content": chunk, "is_chunk": True, "round": round_num }) await asyncio.sleep(0) bear_full += round_marker + bear_response debate_history.append({"agent": "Bear", "round": round_num, "content": bear_response}) yield sse_event("agent", { "agent": "BearResearcher", "role": prompts["role_bear"], "content": "", "is_end": True, "round": round_num }) # === 投资经理总结决策 === decision_msg = "Debate ended, Investment Manager is making final decision..." if language == "en" else "辩论结束,投资经理正在做最终决策..." yield sse_event("phase", {"phase": "decision", "message": decision_msg}) manager_role = "Investment Manager" if language == "en" else "投资经理" yield sse_event("agent", { "agent": "InvestmentManager", "role": manager_role, "content": "", "is_start": True }) # 整理辩论历史 debate_summary = "\n".join([ f"【第{h['round']}轮-{'看多' if h['agent']=='Bull' else '看空'}】{h['content'][:150]}..." 
for h in debate_history ]) decision_prompt = prompts["manager_decision"].format( stock_name=stock_name, stock_code=stock_code, bull_analysis=bull_full[:1000], bear_analysis=bear_full[:1000] ) manager_system_msg = prompts["manager_system"] if language == "en" else "你是一位经验丰富的投资经理,善于在多空观点中做出理性决策。" decision_messages = [ {"role": "system", "content": manager_system_msg}, {"role": "user", "content": decision_prompt} ] decision = "" for chunk in llm_provider.stream(decision_messages): decision += chunk yield sse_event("agent", { "agent": "InvestmentManager", "role": manager_role, "content": chunk, "is_chunk": True }) await asyncio.sleep(0) yield sse_event("agent", { "agent": "InvestmentManager", "role": manager_role, "content": "", "is_end": True }) # 提取评级 if language == "en": rating = "Neutral" for r in ["Strongly Recommend", "Recommend", "Neutral", "Caution", "Avoid"]: if r in decision: rating = r break else: rating = "中性" for r in ["强烈推荐", "推荐", "中性", "谨慎", "回避"]: if r in decision: rating = r break # 发送完成事件 yield sse_event("result", { "success": True, "mode": mode, "debate_id": debate_id, "total_rounds": max_rounds, "bull_analysis": {"analysis": bull_full.strip(), "success": True, "agent_name": "BullResearcher", "agent_role": prompts["role_bull"]}, "bear_analysis": {"analysis": bear_full.strip(), "success": True, "agent_name": "BearResearcher", "agent_role": prompts["role_bear"]}, "final_decision": {"decision": decision, "rating": rating, "success": True, "agent_name": "InvestmentManager", "agent_role": manager_role}, "debate_history": debate_history }) else: # parallel 模式 - 也使用流式,但并行展示 yield sse_event("phase", {"phase": "parallel_analysis", "message": "Bull/Bear 并行分析中..."}) # 由于是并行,我们交替输出 bull_prompt = f"""你是看多研究员,请从积极角度分析 {stock_name}({stock_code}): 背景资料: {context[:1500]} 新闻: {json.dumps([n.get('title', '') for n in news_data[:5]], ensure_ascii=False)} 请给出完整的看多分析报告。""" bear_prompt = f"""你是看空研究员,请从风险角度分析 {stock_name}({stock_code}): 背景资料: {context[:1500]} 新闻: {json.dumps([n.get('title', '') for n in news_data[:5]], ensure_ascii=False)} 请给出完整的看空分析报告。""" # Bull 流式输出 yield sse_event("agent", {"agent": "BullResearcher", "role": "看多研究员", "content": "", "is_start": True}) bull_analysis = "" for chunk in llm_provider.stream([ {"role": "system", "content": "你是一位乐观但理性的股票研究员。"}, {"role": "user", "content": bull_prompt} ]): bull_analysis += chunk yield sse_event("agent", {"agent": "BullResearcher", "role": "看多研究员", "content": chunk, "is_chunk": True}) await asyncio.sleep(0) yield sse_event("agent", {"agent": "BullResearcher", "role": "看多研究员", "content": "", "is_end": True}) # Bear 流式输出 yield sse_event("agent", {"agent": "BearResearcher", "role": "看空研究员", "content": "", "is_start": True}) bear_analysis = "" for chunk in llm_provider.stream([ {"role": "system", "content": "你是一位谨慎的股票研究员。"}, {"role": "user", "content": bear_prompt} ]): bear_analysis += chunk yield sse_event("agent", {"agent": "BearResearcher", "role": "看空研究员", "content": chunk, "is_chunk": True}) await asyncio.sleep(0) yield sse_event("agent", {"agent": "BearResearcher", "role": "看空研究员", "content": "", "is_end": True}) # 投资经理决策 yield sse_event("phase", {"phase": "decision", "message": "投资经理决策中..."}) yield sse_event("agent", {"agent": "InvestmentManager", "role": "投资经理", "content": "", "is_start": True}) decision_prompt = f"""综合以下多空观点,对 {stock_name} 做出投资决策: 【看多】{bull_analysis[:800]} 【看空】{bear_analysis[:800]} 请给出评级[强烈推荐/推荐/中性/谨慎/回避]和决策理由。""" decision = "" for chunk in llm_provider.stream([ {"role": "system", "content": "你是投资经理。"}, {"role": 
"user", "content": decision_prompt} ]): decision += chunk yield sse_event("agent", {"agent": "InvestmentManager", "role": "投资经理", "content": chunk, "is_chunk": True}) await asyncio.sleep(0) yield sse_event("agent", {"agent": "InvestmentManager", "role": "投资经理", "content": "", "is_end": True}) rating = "中性" for r in ["强烈推荐", "推荐", "中性", "谨慎", "回避"]: if r in decision: rating = r break yield sse_event("result", { "success": True, "mode": mode, "bull_analysis": {"analysis": bull_analysis, "success": True, "agent_name": "BullResearcher", "agent_role": "看多研究员"}, "bear_analysis": {"analysis": bear_analysis, "success": True, "agent_name": "BearResearcher", "agent_role": "看空研究员"}, "final_decision": {"decision": decision, "rating": rating, "success": True, "agent_name": "InvestmentManager", "agent_role": "投资经理"} }) yield sse_event("phase", {"phase": "complete", "message": "分析完成"}) except Exception as e: logger.error(f"SSE Debate error: {e}", exc_info=True) yield sse_event("error", {"message": str(e)}) @router.post("/debate/stream") async def run_stock_debate_stream( request: DebateRequest, db: AsyncSession = Depends(get_db) ): """ 流式辩论分析(SSE) 使用 Server-Sent Events 实时推送辩论过程 """ logger.info(f"🎯 收到流式辩论请求: stock_code={request.stock_code}, mode={request.mode}") # 标准化股票代码 code = request.stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): short_code = code[2:] else: short_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" # 获取关联新闻 from sqlalchemy import text stock_codes_filter = text( "stock_codes @> ARRAY[:code1]::varchar[] OR stock_codes @> ARRAY[:code2]::varchar[]" ).bindparams(code1=short_code, code2=code) news_query = select(News).where(stock_codes_filter).order_by(desc(News.publish_time)).limit(10) result = await db.execute(news_query) news_list = result.scalars().all() news_data = [ { "id": n.id, "title": n.title, "content": n.content[:500] if n.content else "", "sentiment_score": n.sentiment_score, "publish_time": n.publish_time.isoformat() if n.publish_time else None } for n in news_list ] # 获取额外上下文 try: debate_context = await stock_data_service.get_debate_context(code) akshare_context = debate_context.get("summary", "") except Exception as e: logger.warning(f"获取财务数据失败: {e}") akshare_context = "" full_context = "" if request.context: full_context += f"【用户补充】{request.context}\n\n" if akshare_context: full_context += f"【实时数据】{akshare_context}" # 创建 LLM provider llm_provider = get_llm_provider( provider=request.provider, model=request.model ) if request.provider or request.model else get_llm_provider() mode = request.mode or "parallel" stock_name = request.stock_name or code language = request.language or "zh" return StreamingResponse( generate_debate_stream(code, stock_name, mode, full_context, news_data, llm_provider, language), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" # 禁用 nginx 缓冲 } ) # ============ 追问功能 ============ class FollowUpRequest(BaseModel): """追问请求""" stock_code: str = Field(..., description="股票代码") stock_name: Optional[str] = Field(None, description="股票名称") question: str = Field(..., description="用户问题") target_agent: Optional[str] = Field(None, description="目标角色: bull, bear, manager") context: Optional[str] = Field(None, description="之前的辩论摘要") async def generate_followup_stream( stock_code: str, stock_name: str, question: str, target_agent: str, context: str, llm_provider ) -> AsyncGenerator[str, None]: """ 生成追问回复的 SSE 流 """ def sse_event(event_type: str, data: 
Dict) -> str: return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" # 确定回复角色 agent_config = { 'bull': { 'agent': 'BullResearcher', 'role': '多方辩手', 'system': '你是一位看多研究员,擅长从积极角度分析股票。回答用户问题时保持乐观但理性的态度。' }, 'bear': { 'agent': 'BearResearcher', 'role': '空方辩手', 'system': '你是一位看空研究员,擅长发现风险。回答用户问题时保持谨慎,重点指出潜在风险。' }, 'manager': { 'agent': 'InvestmentManager', 'role': '投资经理', 'system': '你是一位经验丰富的投资经理,擅长综合分析和给出投资建议。回答用户问题时客观、专业。' } } config = agent_config.get(target_agent, agent_config['manager']) try: yield sse_event("agent", { "agent": config['agent'], "role": config['role'], "content": "", "is_start": True }) prompt = f"""你正在参与关于 {stock_name}({stock_code}) 的投资讨论。 之前的讨论背景: {context[:1500] if context else '暂无'} 用户现在问你: "{question}" 请以{config['role']}的身份回答(约150-200字):""" messages = [ {"role": "system", "content": config['system']}, {"role": "user", "content": prompt} ] full_response = "" for chunk in llm_provider.stream(messages): full_response += chunk yield sse_event("agent", { "agent": config['agent'], "role": config['role'], "content": chunk, "is_chunk": True }) await asyncio.sleep(0) yield sse_event("agent", { "agent": config['agent'], "role": config['role'], "content": "", "is_end": True }) yield sse_event("complete", {"success": True}) except Exception as e: logger.error(f"Followup error: {e}", exc_info=True) yield sse_event("error", {"message": str(e)}) @router.post("/debate/followup") async def debate_followup(request: FollowUpRequest): """ 辩论追问(SSE) 用户可以在辩论结束后继续提问 - 默认由投资经理回答 - 如果问题中包含 @多方 或 @bull,由多方辩手回答 - 如果问题中包含 @空方 或 @bear,由空方辩手回答 - 如果问题中包含 @数据专员,则生成搜索计划(不直接回答) """ logger.info(f"🎯 收到追问请求: {request.question[:50]}...") # 解析目标角色 question = request.question target = request.target_agent or 'manager' # 1. 检查是否提及数据专员(确认优先模式) if '@数据专员' in question or target == 'data_collector': logger.info("🔍 检测到数据专员提及,生成搜索计划...") # 移除提及词 clean_question = question.replace('@数据专员', '').strip() # 创建数据专员 data_collector = create_data_collector() # 生成计划 plan = await data_collector.generate_search_plan( query=clean_question, stock_code=request.stock_code, stock_name=request.stock_name or request.stock_code ) # 使用 SSE 返回计划事件 async def generate_plan_stream(): # Pydantic V2: 使用 model_dump_json() 或 json.dumps(model_dump()) plan_json = json.dumps(plan.model_dump(), ensure_ascii=False) yield f"event: task_plan\ndata: {plan_json}\n\n" yield "event: complete\ndata: {\"success\": true}\n\n" return StreamingResponse( generate_plan_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } ) # 2. 
普通追问逻辑 # 从问题中解析 @ 提及 if '@多方' in question or '@bull' in question.lower() or '@看多' in question: target = 'bull' question = question.replace('@多方', '').replace('@bull', '').replace('@Bull', '').replace('@看多', '').strip() elif '@空方' in question or '@bear' in question.lower() or '@看空' in question: target = 'bear' question = question.replace('@空方', '').replace('@bear', '').replace('@Bear', '').replace('@看空', '').strip() elif '@经理' in question or '@manager' in question.lower() or '@投资经理' in question: target = 'manager' question = question.replace('@经理', '').replace('@manager', '').replace('@Manager', '').replace('@投资经理', '').strip() # 创建 LLM provider llm_provider = get_llm_provider() stock_name = request.stock_name or request.stock_code return StreamingResponse( generate_followup_stream( request.stock_code, stock_name, question, target, request.context or "", llm_provider ), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } ) @router.post("/search/execute") async def execute_search(request: SearchExecuteRequest): """ 执行确认后的搜索计划(SSE) """ from ...agents.data_collector_v2 import SearchPlan logger.info(f"🚀 收到搜索执行请求: {request.plan.get('plan_id')}") try: # 反序列化计划 plan = SearchPlan(**request.plan) async def generate_search_results(): yield f"event: phase\ndata: {json.dumps({'phase': 'executing', 'message': '正在执行搜索任务...'}, ensure_ascii=False)}\n\n" data_collector = create_data_collector() # 执行计划 results = await data_collector.execute_search_plan(plan) # 发送结果事件 yield f"event: agent\ndata: {json.dumps({'agent': 'DataCollector', 'role': '数据专员', 'content': results.get('summary', ''), 'is_chunk': False}, ensure_ascii=False)}\n\n" yield f"event: result\ndata: {json.dumps(results, ensure_ascii=False)}\n\n" yield "event: complete\ndata: {\"success\": true}\n\n" return StreamingResponse( generate_search_results(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } ) except Exception as e: logger.error(f"执行搜索计划失败: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.get("/debate/{debate_id}", response_model=DebateResponse) async def get_debate_result(debate_id: str): """ 获取辩论结果 - **debate_id**: 辩论ID """ if debate_id not in debate_results: raise HTTPException(status_code=404, detail="Debate not found") result = debate_results[debate_id] return DebateResponse( success=result.get("success", False), debate_id=debate_id, stock_code=result.get("stock_code", ""), stock_name=result.get("stock_name"), bull_analysis=result.get("bull_analysis"), bear_analysis=result.get("bear_analysis"), final_decision=result.get("final_decision"), trajectory=result.get("trajectory"), execution_time=result.get("execution_time") ) @router.get("/logs", response_model=List[AgentLogEntry]) async def get_agent_logs( limit: int = Query(50, le=200), agent_name: Optional[str] = Query(None, description="按智能体名称筛选"), status: Optional[str] = Query(None, description="按状态筛选: started, completed, failed") ): """ 获取智能体执行日志 - **limit**: 返回数量限制 - **agent_name**: 按智能体名称筛选 - **status**: 按状态筛选 """ logs = execution_logs.copy() # 筛选 if agent_name: logs = [log for log in logs if log.get("agent_name") == agent_name] if status: logs = [log for log in logs if log.get("status") == status] # 按时间倒序 logs.sort(key=lambda x: x.get("timestamp", ""), reverse=True) # 限制数量 logs = logs[:limit] return [AgentLogEntry(**log) for log in logs] @router.get("/metrics", response_model=AgentMetrics) async def 
get_agent_metrics(): """ 获取智能体性能指标 """ total = len(execution_logs) successful = len([log for log in execution_logs if log.get("status") == "completed"]) failed = len([log for log in execution_logs if log.get("status") == "failed"]) # 计算平均执行时间 execution_times = [ log.get("execution_time", 0) for log in execution_logs if log.get("execution_time") is not None ] avg_time = sum(execution_times) / len(execution_times) if execution_times else 0 # 按智能体统计 agent_stats = {} for log in execution_logs: agent_name = log.get("agent_name", "Unknown") if agent_name not in agent_stats: agent_stats[agent_name] = { "total": 0, "successful": 0, "failed": 0, "avg_time": 0, "times": [] } agent_stats[agent_name]["total"] += 1 if log.get("status") == "completed": agent_stats[agent_name]["successful"] += 1 elif log.get("status") == "failed": agent_stats[agent_name]["failed"] += 1 if log.get("execution_time"): agent_stats[agent_name]["times"].append(log["execution_time"]) # 计算每个智能体的平均时间 for agent_name, stats in agent_stats.items(): if stats["times"]: stats["avg_time"] = sum(stats["times"]) / len(stats["times"]) del stats["times"] # 不返回原始时间列表 # 最近活动 recent_logs = sorted( execution_logs, key=lambda x: x.get("timestamp", ""), reverse=True )[:10] recent_activity = [ { "timestamp": log.get("timestamp"), "agent_name": log.get("agent_name"), "action": log.get("action"), "status": log.get("status") } for log in recent_logs ] return AgentMetrics( total_executions=total, successful_executions=successful, failed_executions=failed, avg_execution_time=round(avg_time, 2), agent_stats=agent_stats, recent_activity=recent_activity ) @router.get("/trajectory/{debate_id}", response_model=List[TrajectoryStep]) async def get_debate_trajectory(debate_id: str): """ 获取辩论执行轨迹 - **debate_id**: 辩论ID """ if debate_id not in debate_results: raise HTTPException(status_code=404, detail="Debate not found") result = debate_results[debate_id] trajectory = result.get("trajectory", []) steps = [] for i, step in enumerate(trajectory): steps.append(TrajectoryStep( step_id=f"{debate_id}_step_{i}", step_name=step.get("step", "unknown"), timestamp=step.get("timestamp", ""), agent_name=step.get("data", {}).get("agent"), input_data=None, # 可以扩展 output_data=step.get("data"), duration=None, status="completed" )) return steps @router.delete("/logs") async def clear_logs(): """ 清空执行日志(仅用于开发测试) """ global execution_logs count = len(execution_logs) execution_logs = [] return {"message": f"Cleared {count} logs"} @router.get("/available") async def get_available_agents(): """ 获取可用的智能体列表 """ return { "agents": [ { "name": "NewsAnalyst", "role": "金融新闻分析师", "description": "分析金融新闻的情感、影响和关键信息", "status": "active" }, { "name": "BullResearcher", "role": "看多研究员", "description": "从积极角度分析股票,发现投资机会", "status": "active" }, { "name": "BearResearcher", "role": "看空研究员", "description": "从风险角度分析股票,识别潜在问题", "status": "active" }, { "name": "InvestmentManager", "role": "投资经理", "description": "综合多方观点,做出投资决策", "status": "active" }, { "name": "SearchAnalyst", "role": "搜索分析师", "description": "动态获取数据,支持 AkShare、BochaAI、网页搜索等", "status": "active" } ], "workflows": [ { "name": "NewsAnalysisWorkflow", "description": "新闻分析工作流:爬取 -> 清洗 -> 情感分析", "agents": ["NewsAnalyst"], "status": "active" }, { "name": "InvestmentDebateWorkflow", "description": "投资辩论工作流:Bull vs Bear 多智能体辩论", "agents": ["BullResearcher", "BearResearcher", "InvestmentManager"], "status": "active" } ] } # ============ 辩论历史 API ============ class DebateHistoryRequest(BaseModel): """保存辩论历史请求""" stock_code: str = Field(..., 
description="股票代码") sessions: List[Dict[str, Any]] = Field(..., description="会话列表") class DebateHistoryResponse(BaseModel): """辩论历史响应""" success: bool stock_code: str sessions: List[Dict[str, Any]] = [] message: Optional[str] = None @router.get("/debate/history/{stock_code}", response_model=DebateHistoryResponse) async def get_debate_history( stock_code: str, limit: int = Query(10, le=50, description="返回会话数量限制"), db: AsyncSession = Depends(get_db) ): """ 获取股票的辩论历史 - **stock_code**: 股票代码 - **limit**: 返回数量限制(默认10,最大50) """ from ...models.debate_history import DebateHistory try: # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" # 查询历史记录 query = select(DebateHistory).where( DebateHistory.stock_code == code ).order_by(desc(DebateHistory.updated_at)).limit(limit) result = await db.execute(query) histories = result.scalars().all() sessions = [] for h in histories: sessions.append({ "id": h.session_id, "stockCode": h.stock_code, "stockName": h.stock_name, "mode": h.mode, "messages": h.messages, "createdAt": h.created_at.isoformat() if h.created_at else None, "updatedAt": h.updated_at.isoformat() if h.updated_at else None }) return DebateHistoryResponse( success=True, stock_code=code, sessions=sessions ) except Exception as e: logger.error(f"获取辩论历史失败: {e}", exc_info=True) return DebateHistoryResponse( success=False, stock_code=stock_code, message=str(e) ) @router.post("/debate/history", response_model=DebateHistoryResponse) async def save_debate_history( request: DebateHistoryRequest, db: AsyncSession = Depends(get_db) ): """ 保存辩论历史 - **stock_code**: 股票代码 - **sessions**: 会话列表 """ from ...models.debate_history import DebateHistory try: # 标准化股票代码 code = request.stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" saved_count = 0 for session_data in request.sessions: session_id = session_data.get("id") if not session_id: continue messages = session_data.get("messages", []) logger.info(f"📥 Processing session {session_id}: {len(messages)} messages") logger.info(f"📥 Message roles: {[m.get('role') for m in messages]}") # 检查是否已存在 existing_query = select(DebateHistory).where( DebateHistory.session_id == session_id ) existing_result = await db.execute(existing_query) existing = existing_result.scalar_one_or_none() if existing: # 更新现有记录 logger.info(f"📥 Updating existing session, old messages: {len(existing.messages)}, new: {len(messages)}") existing.messages = messages existing.mode = session_data.get("mode") existing.updated_at = datetime.utcnow() else: # 解析 created_at,确保是 naive datetime(去掉时区信息) created_at_str = session_data.get("createdAt") if created_at_str: # 处理 ISO 格式字符串,移除末尾的 'Z' 并转换 if created_at_str.endswith('Z'): created_at_str = created_at_str[:-1] + '+00:00' parsed_dt = datetime.fromisoformat(created_at_str) # 转换为 naive datetime (去掉时区信息) if parsed_dt.tzinfo is not None: created_at = parsed_dt.replace(tzinfo=None) else: created_at = parsed_dt else: created_at = datetime.utcnow() # 创建新记录 new_history = DebateHistory( session_id=session_id, stock_code=code, stock_name=session_data.get("stockName"), mode=session_data.get("mode"), messages=session_data.get("messages", []), created_at=created_at, updated_at=datetime.utcnow() ) db.add(new_history) saved_count += 1 await db.commit() logger.info(f"保存了 {saved_count} 个辩论会话到数据库") return DebateHistoryResponse( success=True, stock_code=code, message=f"成功保存 {saved_count} 个会话" ) 
except Exception as e: logger.error(f"保存辩论历史失败: {e}", exc_info=True) await db.rollback() return DebateHistoryResponse( success=False, stock_code=request.stock_code, message=str(e) ) @router.delete("/debate/history/{stock_code}") async def delete_debate_history( stock_code: str, session_id: Optional[str] = Query(None, description="删除指定会话,不传则删除所有"), db: AsyncSession = Depends(get_db) ): """ 删除辩论历史 - **stock_code**: 股票代码 - **session_id**: 会话ID(可选,不传则删除该股票的所有历史) """ from ...models.debate_history import DebateHistory from sqlalchemy import delete try: # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" if session_id: # 删除指定会话 stmt = delete(DebateHistory).where( DebateHistory.session_id == session_id ) else: # 删除该股票的所有会话 stmt = delete(DebateHistory).where( DebateHistory.stock_code == code ) result = await db.execute(stmt) await db.commit() deleted_count = result.rowcount return { "success": True, "stock_code": code, "deleted_count": deleted_count, "message": f"删除了 {deleted_count} 条记录" } except Exception as e: logger.error(f"删除辩论历史失败: {e}", exc_info=True) await db.rollback() return { "success": False, "stock_code": stock_code, "message": str(e) } ================================================ FILE: backend/app/api/v1/alpha_mining.py ================================================ """ Alpha Mining REST API 提供因子挖掘相关的 HTTP 接口。 Endpoints: - POST /alpha-mining/mine - 启动因子挖掘任务 - POST /alpha-mining/mine/stream - SSE 流式训练进度 - POST /alpha-mining/evaluate - 评估因子表达式 - POST /alpha-mining/generate - 生成候选因子 - POST /alpha-mining/compare-sentiment - 情感融合效果对比 - POST /alpha-mining/agent-demo - AgenticX Agent 调用演示 - GET /alpha-mining/factors - 获取已发现的因子列表 - GET /alpha-mining/status - 获取挖掘状态 - GET /alpha-mining/operators - 获取操作符列表 """ from fastapi import APIRouter, HTTPException, BackgroundTasks from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field from typing import List, Optional, Dict, Any, AsyncGenerator from datetime import datetime import logging import uuid import asyncio import json import queue import threading logger = logging.getLogger(__name__) router = APIRouter(prefix="/alpha-mining", tags=["Alpha Mining"]) # 存储挖掘任务状态 _mining_tasks: Dict[str, Dict[str, Any]] = {} _discovered_factors: List[Dict[str, Any]] = [] # ============================================================================ # Request/Response Models # ============================================================================ class MineRequest(BaseModel): """因子挖掘请求""" stock_code: Optional[str] = Field(None, description="股票代码") num_steps: int = Field(100, ge=1, le=10000, description="训练步数") use_sentiment: bool = Field(True, description="是否使用情感特征") batch_size: int = Field(16, ge=1, le=128, description="批量大小") class EvaluateRequest(BaseModel): """因子评估请求""" formula: str = Field(..., description="因子表达式") stock_code: Optional[str] = Field(None, description="股票代码") class GenerateRequest(BaseModel): """因子生成请求""" batch_size: int = Field(10, ge=1, le=100, description="生成数量") max_len: int = Field(8, ge=4, le=16, description="最大表达式长度") class FactorResponse(BaseModel): """因子响应""" formula: List[int] = Field(..., description="Token 序列") formula_str: str = Field(..., description="表达式字符串") sortino: float = Field(..., description="Sortino Ratio") sharpe: Optional[float] = Field(None, description="Sharpe Ratio") ic: Optional[float] = Field(None, description="IC") discovered_at: Optional[str] = Field(None, description="发现时间") 
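# Illustrative sketch (editor's note, not part of the original module): the Field
# constraints on the request models above are enforced by Pydantic when FastAPI
# parses the request body, so out-of-range values are rejected with a 422 response
# before the endpoint code runs. Assuming the module imports cleanly, for example:
#
#     from pydantic import ValidationError
#
#     req = MineRequest()            # defaults: num_steps=100, use_sentiment=True, batch_size=16
#     try:
#         MineRequest(num_steps=0)   # violates num_steps ge=1
#     except ValidationError as exc:
#         print(exc.errors()[0]["loc"])   # -> ('num_steps',)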
class MineResponse(BaseModel): """挖掘响应""" success: bool task_id: str message: str best_factor: Optional[FactorResponse] = None class EvaluateResponse(BaseModel): """评估响应""" success: bool formula: str metrics: Optional[Dict[str, float]] = None error: Optional[str] = None class GenerateResponse(BaseModel): """生成响应""" success: bool generated: int valid: int factors: List[Dict[str, Any]] class TaskStatusResponse(BaseModel): """任务状态响应""" task_id: str status: str # pending, running, completed, failed progress: float # 0-100 result: Optional[Dict[str, Any]] = None error: Optional[str] = None started_at: Optional[str] = None completed_at: Optional[str] = None class SentimentCompareRequest(BaseModel): """情感融合对比请求""" num_steps: int = Field(50, ge=10, le=500, description="训练步数") batch_size: int = Field(16, ge=1, le=64, description="批量大小") class SentimentCompareResponse(BaseModel): """情感融合对比响应""" success: bool with_sentiment: Dict[str, Any] = Field(..., description="含情感特征的结果") without_sentiment: Dict[str, Any] = Field(..., description="不含情感特征的结果") improvement: Dict[str, float] = Field(..., description="改进幅度") class AgentDemoRequest(BaseModel): """Agent 调用演示请求""" stock_code: Optional[str] = Field(None, description="股票代码") num_steps: int = Field(30, ge=10, le=200, description="训练步数") use_sentiment: bool = Field(True, description="使用情感特征") class AgentDemoResponse(BaseModel): """Agent 调用演示响应""" success: bool agent_name: str tool_name: str input_params: Dict[str, Any] output: Optional[Dict[str, Any]] = None execution_time: float logs: List[str] = [] # ============================================================================ # Helper Functions # ============================================================================ def _get_alpha_mining_components(): """获取 Alpha Mining 组件""" try: from ...alpha_mining import ( AlphaMiningConfig, FactorVocab, FactorVM, AlphaGenerator, AlphaTrainer, FactorEvaluator, generate_mock_data ) config = AlphaMiningConfig() vocab = FactorVocab() vm = FactorVM(vocab=vocab) generator = AlphaGenerator(vocab=vocab, config=config) evaluator = FactorEvaluator(config=config) return { "config": config, "vocab": vocab, "vm": vm, "generator": generator, "evaluator": evaluator, "generate_mock_data": generate_mock_data } except ImportError as e: logger.error(f"Failed to import Alpha Mining: {e}") raise HTTPException( status_code=503, detail="Alpha Mining module not available" ) async def _run_mining_task(task_id: str, request: MineRequest): """后台运行挖掘任务""" global _discovered_factors try: _mining_tasks[task_id]["status"] = "running" _mining_tasks[task_id]["started_at"] = datetime.utcnow().isoformat() components = _get_alpha_mining_components() from ...alpha_mining import AlphaTrainer # 准备数据 features, returns = components["generate_mock_data"]( num_samples=50, num_features=6, time_steps=252, seed=42 ) # 创建训练器 config = components["config"] config.batch_size = request.batch_size trainer = AlphaTrainer( generator=components["generator"], vocab=components["vocab"], config=config ) # 训练 result = trainer.train( features=features, returns=returns, num_steps=request.num_steps, progress_bar=False ) # 保存结果 if result["best_formula"]: factor_info = { "formula": result["best_formula"], "formula_str": result["best_formula_str"], "sortino": result["best_score"], "discovered_at": datetime.utcnow().isoformat(), "task_id": task_id, "stock_code": request.stock_code } _discovered_factors.append(factor_info) # 保持只存储最优的 100 个 _discovered_factors.sort(key=lambda x: x.get("sortino", 0), reverse=True) 
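        # NOTE: _discovered_factors (like _mining_tasks) is a module-level, in-process
        # store: its contents are lost on restart and are not shared across worker processes.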
_discovered_factors = _discovered_factors[:100] _mining_tasks[task_id]["status"] = "completed" _mining_tasks[task_id]["progress"] = 100 _mining_tasks[task_id]["completed_at"] = datetime.utcnow().isoformat() _mining_tasks[task_id]["result"] = { "best_factor": result["best_formula_str"], "best_score": result["best_score"], "total_steps": result["total_steps"] } except Exception as e: logger.error(f"Mining task {task_id} failed: {e}") _mining_tasks[task_id]["status"] = "failed" _mining_tasks[task_id]["error"] = str(e) _mining_tasks[task_id]["completed_at"] = datetime.utcnow().isoformat() # ============================================================================ # API Endpoints # ============================================================================ @router.post("/mine", response_model=MineResponse) async def mine_factors( request: MineRequest, background_tasks: BackgroundTasks ): """ 启动因子挖掘任务 使用强化学习自动发现有效的交易因子。 任务在后台执行,可通过 /status/{task_id} 查询进度。 """ task_id = str(uuid.uuid4()) # 初始化任务状态 _mining_tasks[task_id] = { "status": "pending", "progress": 0, "request": request.dict(), "created_at": datetime.utcnow().isoformat() } # 添加后台任务 background_tasks.add_task(_run_mining_task, task_id, request) return MineResponse( success=True, task_id=task_id, message=f"因子挖掘任务已启动,预计 {request.num_steps} 步训练" ) @router.post("/mine/stream") async def mine_factors_stream(request: MineRequest): """ SSE 流式返回训练进度 实时推送每步训练指标,包括 loss、reward、best_score 等。 前端可使用 EventSource 订阅。 """ async def event_generator() -> AsyncGenerator[str, None]: try: components = _get_alpha_mining_components() from ...alpha_mining import AlphaTrainer # 准备数据 features, returns = components["generate_mock_data"]( num_samples=50, num_features=6, time_steps=252, seed=42 ) # 创建训练器 config = components["config"] config.batch_size = request.batch_size trainer = AlphaTrainer( generator=components["generator"], vocab=components["vocab"], config=config ) # 使用队列在线程间传递数据 metrics_queue: queue.Queue = queue.Queue() training_complete = threading.Event() training_error: List[str] = [] def step_callback(metrics: Dict[str, Any]): """每步训练回调,将指标放入队列""" metrics_queue.put(metrics) def run_training(): """在后台线程中运行训练""" try: trainer.train( features=features, returns=returns, num_steps=request.num_steps, progress_bar=False, step_callback=step_callback ) except Exception as e: training_error.append(str(e)) finally: training_complete.set() # 启动训练线程 training_thread = threading.Thread(target=run_training) training_thread.start() # 发送开始事件 yield f"event: start\ndata: {json.dumps({'status': 'started', 'total_steps': request.num_steps})}\n\n" # 流式发送训练进度 while not training_complete.is_set() or not metrics_queue.empty(): try: metrics = metrics_queue.get(timeout=0.1) event_data = { "step": metrics.get("step", 0), "progress": metrics.get("progress", 0), "loss": round(metrics.get("loss", 0), 6), "avg_reward": round(metrics.get("avg_reward", 0), 6), "max_reward": round(metrics.get("max_reward", 0), 6), "valid_ratio": round(metrics.get("valid_ratio", 0), 4), "best_score": round(metrics.get("best_score", -999), 6), "best_formula": metrics.get("best_formula", ""), } yield f"event: progress\ndata: {json.dumps(event_data)}\n\n" except queue.Empty: continue # 等待训练线程结束 training_thread.join(timeout=5) # 发送完成事件 if training_error: yield f"event: error\ndata: {json.dumps({'error': training_error[0]})}\n\n" else: final_result = { "status": "completed", "best_score": round(trainer.best_score, 6), "best_formula": trainer.best_formula_str, "total_steps": trainer.step_count, } yield f"event: 
complete\ndata: {json.dumps(final_result)}\n\n" # 保存发现的因子 if trainer.best_formula: _discovered_factors.append({ "formula": trainer.best_formula, "formula_str": trainer.best_formula_str, "sortino": trainer.best_score, "discovered_at": datetime.utcnow().isoformat(), "stock_code": request.stock_code }) except Exception as e: logger.error(f"SSE streaming error: {e}") yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" return StreamingResponse( event_generator(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", } ) @router.post("/compare-sentiment", response_model=SentimentCompareResponse) async def compare_sentiment_effect(request: SentimentCompareRequest): """ 对比有/无情感特征的因子挖掘效果 分别使用纯技术特征和技术+情感特征进行因子挖掘, 对比最终效果差异。 """ try: components = _get_alpha_mining_components() from ...alpha_mining import AlphaTrainer, AlphaMiningConfig results = {} for use_sentiment in [False, True]: # 准备数据 num_features = 6 if use_sentiment else 4 # 4个技术特征 + 2个情感特征 features, returns = components["generate_mock_data"]( num_samples=50, num_features=num_features, time_steps=252, seed=42 ) # 训练 config = AlphaMiningConfig() config.batch_size = request.batch_size trainer = AlphaTrainer( generator=components["generator"], vocab=components["vocab"], config=config ) result = trainer.train( features=features, returns=returns, num_steps=request.num_steps, progress_bar=False ) key = "with_sentiment" if use_sentiment else "without_sentiment" results[key] = { "best_score": round(result["best_score"], 6), "best_formula": result["best_formula_str"], "total_steps": result["total_steps"], "num_features": num_features, } # 计算改进幅度 with_score = results["with_sentiment"]["best_score"] without_score = results["without_sentiment"]["best_score"] if without_score != 0: improvement_pct = (with_score - without_score) / abs(without_score) * 100 else: improvement_pct = 0 if with_score == 0 else 100 improvement = { "score_diff": round(with_score - without_score, 6), "improvement_pct": round(improvement_pct, 2), } return SentimentCompareResponse( success=True, with_sentiment=results["with_sentiment"], without_sentiment=results["without_sentiment"], improvement=improvement ) except Exception as e: logger.error(f"Sentiment comparison failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/agent-demo", response_model=AgentDemoResponse) async def agent_alpha_mining_demo(request: AgentDemoRequest): """ 演示 AgenticX Agent 调用 AlphaMiningTool 展示如何通过 Agent 接口调用因子挖掘功能。 """ import time start_time = time.time() logs = [] try: logs.append(f"[{datetime.utcnow().isoformat()}] Agent 初始化...") logs.append(f"[{datetime.utcnow().isoformat()}] 调用 AlphaMiningTool...") # 模拟 Agent 调用 components = _get_alpha_mining_components() from ...alpha_mining import AlphaTrainer input_params = { "stock_code": request.stock_code, "num_steps": request.num_steps, "use_sentiment": request.use_sentiment, } logs.append(f"[{datetime.utcnow().isoformat()}] Tool 参数: {json.dumps(input_params)}") # 准备数据 features, returns = components["generate_mock_data"]( num_samples=50, num_features=6 if request.use_sentiment else 4, time_steps=252, seed=42 ) logs.append(f"[{datetime.utcnow().isoformat()}] 数据准备完成") # 训练 trainer = AlphaTrainer( generator=components["generator"], vocab=components["vocab"], config=components["config"] ) logs.append(f"[{datetime.utcnow().isoformat()}] 开始训练...") result = trainer.train( features=features, returns=returns, num_steps=request.num_steps, progress_bar=False ) 
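        # This demo invokes the trainer directly in-process rather than routing the call
        # through an actual AgenticX Agent; result carries best_formula_str, best_score
        # and total_steps, which are surfaced in the response built below.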
logs.append(f"[{datetime.utcnow().isoformat()}] 训练完成") execution_time = time.time() - start_time output = { "best_formula": result["best_formula_str"], "best_score": round(result["best_score"], 6), "total_steps": result["total_steps"], } logs.append(f"[{datetime.utcnow().isoformat()}] 返回结果: {json.dumps(output)}") return AgentDemoResponse( success=True, agent_name="QuantitativeAgent", tool_name="AlphaMiningTool", input_params=input_params, output=output, execution_time=round(execution_time, 2), logs=logs ) except Exception as e: execution_time = time.time() - start_time logs.append(f"[{datetime.utcnow().isoformat()}] 错误: {str(e)}") return AgentDemoResponse( success=False, agent_name="QuantitativeAgent", tool_name="AlphaMiningTool", input_params=request.dict(), output=None, execution_time=round(execution_time, 2), logs=logs ) @router.post("/evaluate", response_model=EvaluateResponse) async def evaluate_factor(request: EvaluateRequest): """ 评估因子表达式 对指定的因子表达式进行回测评估,返回各项指标。 """ try: components = _get_alpha_mining_components() vm = components["vm"] evaluator = components["evaluator"] # 解析公式 tokens = [] parts = request.formula.replace("(", " ").replace(")", " ").replace(",", " ").split() for part in parts: part = part.strip() if not part: continue try: token = vm.vocab.name_to_token(part) tokens.append(token) except (ValueError, KeyError): continue if not tokens: return EvaluateResponse( success=False, formula=request.formula, error="无法解析因子表达式" ) # 准备数据 features, returns = components["generate_mock_data"]( num_samples=50, num_features=6, time_steps=252, seed=42 ) # 执行因子 factor = vm.execute(tokens, features) if factor is None: return EvaluateResponse( success=False, formula=request.formula, error="因子执行失败" ) # 评估 metrics = evaluator.evaluate(factor, returns) return EvaluateResponse( success=True, formula=request.formula, metrics={ "sortino_ratio": metrics["sortino_ratio"], "sharpe_ratio": metrics["sharpe_ratio"], "ic": metrics["ic"], "rank_ic": metrics["rank_ic"], "max_drawdown": metrics["max_drawdown"], "turnover": metrics["turnover"], "total_return": metrics["total_return"], "win_rate": metrics["win_rate"] } ) except Exception as e: logger.error(f"Factor evaluation failed: {e}") return EvaluateResponse( success=False, formula=request.formula, error=str(e) ) @router.post("/generate", response_model=GenerateResponse) async def generate_factors(request: GenerateRequest): """ 生成候选因子 使用训练好的模型生成一批候选因子表达式。 """ try: components = _get_alpha_mining_components() generator = components["generator"] vm = components["vm"] evaluator = components["evaluator"] # 生成因子 formulas, _ = generator.generate( batch_size=request.batch_size, max_len=request.max_len ) # 准备数据用于评估 features, returns = components["generate_mock_data"]( num_samples=50, num_features=6, time_steps=252, seed=42 ) # 评估每个因子 results = [] for formula in formulas: factor = vm.execute(formula, features) if factor is not None and factor.std() > 1e-6: try: metrics = evaluator.evaluate(factor, returns) results.append({ "formula": formula, "formula_str": vm.decode(formula), "sortino": round(metrics["sortino_ratio"], 4), "ic": round(metrics["ic"], 4) }) except Exception: continue # 按 Sortino 排序 results.sort(key=lambda x: x["sortino"], reverse=True) return GenerateResponse( success=True, generated=len(formulas), valid=len(results), factors=results[:10] ) except Exception as e: logger.error(f"Factor generation failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/factors") async def get_factors( top_k: int = 10, stock_code: 
Optional[str] = None ): """ 获取已发现的因子列表 返回按 Sortino Ratio 排序的最优因子。 """ factors = _discovered_factors.copy() # 按股票代码过滤 if stock_code: factors = [f for f in factors if f.get("stock_code") == stock_code] # 取 top_k factors = factors[:top_k] return { "success": True, "total": len(_discovered_factors), "returned": len(factors), "factors": factors } @router.get("/status/{task_id}", response_model=TaskStatusResponse) async def get_task_status(task_id: str): """ 获取挖掘任务状态 """ if task_id not in _mining_tasks: raise HTTPException(status_code=404, detail="Task not found") task = _mining_tasks[task_id] return TaskStatusResponse( task_id=task_id, status=task["status"], progress=task.get("progress", 0), result=task.get("result"), error=task.get("error"), started_at=task.get("started_at"), completed_at=task.get("completed_at") ) @router.get("/operators") async def get_operators(): """ 获取支持的操作符列表 """ try: from ...alpha_mining.dsl.ops import OPS_CONFIG, get_op_names from ...alpha_mining.dsl.vocab import FEATURES operators = [] for name, func, arity in OPS_CONFIG: operators.append({ "name": name, "arity": arity, "description": func.__doc__ or "" }) return { "success": True, "features": FEATURES, "operators": operators } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.delete("/tasks/{task_id}") async def delete_task(task_id: str): """ 删除任务记录 """ if task_id not in _mining_tasks: raise HTTPException(status_code=404, detail="Task not found") del _mining_tasks[task_id] return {"success": True, "message": f"Task {task_id} deleted"} ================================================ FILE: backend/app/api/v1/analysis.py ================================================ """ 分析任务 API 路由 """ import logging import asyncio import json from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Body, Request from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from ...core.database import get_db from ...models.database import AsyncSessionLocal from ...services.analysis_service import get_analysis_service logger = logging.getLogger(__name__) router = APIRouter() # Pydantic 模型 class AnalysisRequest(BaseModel): """分析请求模型""" provider: Optional[str] = Field(default=None, description="LLM提供商 (bailian/openai/deepseek/kimi/zhipu)") model: Optional[str] = Field(default=None, description="模型名称") class AnalysisResponse(BaseModel): """分析响应模型""" success: bool analysis_id: Optional[int] = None news_id: int sentiment: Optional[str] = None sentiment_score: Optional[float] = None confidence: Optional[float] = None summary: Optional[str] = None execution_time: Optional[float] = None error: Optional[str] = None class AnalysisDetailResponse(BaseModel): """分析详情响应模型""" model_config = {"from_attributes": True} id: int news_id: int agent_name: str agent_role: Optional[str] = None analysis_result: str summary: Optional[str] = None sentiment: Optional[str] = None sentiment_score: Optional[float] = None confidence: Optional[float] = None execution_time: Optional[float] = None created_at: str class BatchAnalyzeRequest(BaseModel): """批量分析请求模型""" news_ids: List[int] = Field(..., description="要分析的新闻ID列表") provider: Optional[str] = Field(default=None, description="LLM提供商") model: Optional[str] = Field(default=None, description="模型名称") class BatchAnalyzeResponse(BaseModel): """批量分析响应模型""" success: bool message: str total_count: int success_count: int failed_count: int results: List[AnalysisResponse] # 后台任务:执行分析 async def run_analysis_task(news_id: 
int, db: AsyncSession): """ 后台任务:执行新闻分析 """ try: analysis_service = get_analysis_service() result = await analysis_service.analyze_news(news_id, db) logger.info(f"Analysis task completed for news {news_id}: {result}") except Exception as e: logger.error(f"Analysis task failed for news {news_id}: {e}") # API 端点 # 注意:具体路径(如 /news/batch)必须在参数路径(如 /news/{news_id})之前定义 # 否则 FastAPI 会把 "batch" 当作 news_id 参数 @router.post("/news/batch", response_model=BatchAnalyzeResponse) async def batch_analyze_news( request_body: BatchAnalyzeRequest, db: AsyncSession = Depends(get_db) ): """ 批量分析新闻(并发) - **news_ids**: 要分析的新闻ID列表 - **provider**: LLM提供商(可选) - **model**: 模型名称(可选) """ try: logger.info(f"Received batch analyze request: news_ids={request_body.news_ids}, provider={request_body.provider}, model={request_body.model}") if not request_body.news_ids: raise HTTPException(status_code=400, detail="news_ids cannot be empty") analysis_service = get_analysis_service() # 准备LLM provider参数 llm_provider = request_body.provider llm_model = request_body.model # 定义单个新闻的分析任务 # 注意:每个任务需要独立的数据库会话,因为SQLAlchemy异步会话不支持并发操作 async def analyze_single_news(news_id: int) -> AnalysisResponse: # 为每个任务创建独立的数据库会话 async with AsyncSessionLocal() as task_db: try: result = await analysis_service.analyze_news( news_id, task_db, llm_provider=llm_provider, llm_model=llm_model ) # 提交事务 await task_db.commit() if result.get("success"): return AnalysisResponse( success=True, analysis_id=result.get("analysis_id"), news_id=news_id, sentiment=result.get("sentiment"), sentiment_score=result.get("sentiment_score"), confidence=result.get("confidence"), summary=result.get("summary"), execution_time=result.get("execution_time"), ) else: return AnalysisResponse( success=False, news_id=news_id, error=result.get("error") ) except Exception as e: # 发生错误时回滚事务 await task_db.rollback() logger.error(f"Failed to analyze news {news_id}: {e}", exc_info=True) return AnalysisResponse( success=False, news_id=news_id, error=str(e) ) # 并发执行所有分析任务 logger.info(f"Starting batch analysis for {len(request_body.news_ids)} news items") results = await asyncio.gather(*[analyze_single_news(news_id) for news_id in request_body.news_ids]) # 统计结果 success_count = sum(1 for r in results if r.success) failed_count = len(results) - success_count logger.info(f"Batch analysis completed: {success_count} succeeded, {failed_count} failed") return BatchAnalyzeResponse( success=True, message=f"批量分析完成:成功 {success_count} 条,失败 {failed_count} 条", total_count=len(request_body.news_ids), success_count=success_count, failed_count=failed_count, results=results ) except HTTPException: raise except Exception as e: logger.error(f"Failed to batch analyze news: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.post("/news/{news_id}", response_model=AnalysisResponse) async def analyze_news( news_id: int, request: Optional[AnalysisRequest] = Body(None), background_tasks: BackgroundTasks = None, db: AsyncSession = Depends(get_db) ): """ 触发新闻分析任务 - **news_id**: 新闻ID - **provider**: LLM提供商(可选) - **model**: 模型名称(可选) Returns: 分析任务状态 """ try: analysis_service = get_analysis_service() # 准备LLM provider参数 llm_provider = None llm_model = None if request: llm_provider = request.provider llm_model = request.model if llm_provider or llm_model: logger.info(f"Using custom LLM config: provider={llm_provider}, model={llm_model}") # 执行分析(同步,便于快速验证MVP) # 在生产环境中,应该使用后台任务 result = await analysis_service.analyze_news( news_id, db, llm_provider=llm_provider, llm_model=llm_model ) if 
result.get("success"): return AnalysisResponse( success=True, analysis_id=result.get("analysis_id"), news_id=news_id, sentiment=result.get("sentiment"), sentiment_score=result.get("sentiment_score"), confidence=result.get("confidence"), summary=result.get("summary"), execution_time=result.get("execution_time"), ) else: return AnalysisResponse( success=False, news_id=news_id, error=result.get("error") ) except Exception as e: logger.error(f"Failed to analyze news {news_id}: {e}", exc_info=True) return AnalysisResponse( success=False, news_id=news_id, error=str(e) ) @router.get("/news/{news_id}/all", response_model=List[AnalysisDetailResponse]) async def get_news_analyses( news_id: int, db: AsyncSession = Depends(get_db) ): """ 获取指定新闻的所有分析结果 - **news_id**: 新闻ID """ try: analysis_service = get_analysis_service() results = await analysis_service.get_analyses_by_news_id(news_id, db) return [AnalysisDetailResponse(**result) for result in results] except Exception as e: logger.error(f"Failed to get analyses for news {news_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{analysis_id}", response_model=AnalysisDetailResponse) async def get_analysis_detail( analysis_id: int, db: AsyncSession = Depends(get_db) ): """ 获取分析结果详情 - **analysis_id**: 分析ID """ try: analysis_service = get_analysis_service() result = await analysis_service.get_analysis_by_id(analysis_id, db) if not result: raise HTTPException(status_code=404, detail="Analysis not found") return AnalysisDetailResponse(**result) except HTTPException: raise except Exception as e: logger.error(f"Failed to get analysis {analysis_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) ================================================ FILE: backend/app/api/v1/debug.py ================================================ """ 调试 API - 用于测试爬虫和内容提取 """ import re import logging from typing import Optional from fastapi import APIRouter, HTTPException from pydantic import BaseModel import requests from bs4 import BeautifulSoup logger = logging.getLogger(__name__) router = APIRouter() class CrawlRequest(BaseModel): url: str return_html: bool = True # 是否返回原始 HTML class CrawlResponse(BaseModel): url: str title: Optional[str] = None content: Optional[str] = None content_length: int = 0 html_length: int = 0 raw_html: Optional[str] = None # 原始 HTML(可选) debug_info: dict = {} def extract_chinese_ratio(text: str) -> float: """计算中文字符比例""" pattern = re.compile(r'[\u4e00-\u9fa5]+') chinese_chars = pattern.findall(text) chinese_count = sum(len(chars) for chars in chinese_chars) total_count = len(text) return chinese_count / total_count if total_count > 0 else 0 def clean_text(text: str) -> str: """清理文本""" text = re.sub(r'<[^>]+>', '', text) text = text.replace('\u3000', ' ') text = ' '.join(text.split()) return text.strip() def is_noise_text(text: str) -> bool: """判断是否为噪音文本""" noise_patterns = [ r'^责任编辑', r'^编辑[::]', r'^来源[::]', r'^声明[::]', r'^免责声明', r'^版权', r'^copyright', r'^点击进入', r'^相关阅读', r'^延伸阅读', r'登录新浪财经APP', r'搜索【信披】', r'缩小字体', r'放大字体', r'收藏', r'微博', r'微信', r'分享', r'腾讯QQ', ] text_lower = text.lower().strip() for pattern in noise_patterns: if re.search(pattern, text_lower, re.I): return True return False def extract_content_from_html(html: str, url: str) -> tuple[str, str, dict]: """ 从 HTML 中提取内容 返回: (title, content, debug_info) """ soup = BeautifulSoup(html, 'lxml') debug_info = { "selectors_tried": [], "selector_matched": None, "total_lines_raw": 0, "lines_kept": 0, "lines_filtered": 0, } # 提取标题 title = "" title_tag = 
soup.find('h1', class_='main-title') or soup.find('h1') or soup.find('title') if title_tag: title = title_tag.get_text().strip() title = re.sub(r'[-_].*?(新浪|财经|网)', '', title).strip() # 内容选择器(按优先级) content_selectors = [ {'id': 'artibody'}, {'class': 'article-content'}, {'class': 'article'}, {'id': 'article'}, {'class': 'content'}, {'class': 'news-content'}, ] for selector in content_selectors: debug_info["selectors_tried"].append(str(selector)) content_div = soup.find(['div', 'article'], selector) if content_div: debug_info["selector_matched"] = str(selector) # 移除噪音元素 for tag in content_div.find_all(['script', 'style', 'iframe', 'ins', 'select', 'input', 'button', 'form']): tag.decompose() for ad in content_div.find_all(class_=re.compile(r'ad|banner|share|otherContent|recommend|app-guide', re.I)): ad.decompose() # 获取全文 full_text = content_div.get_text(separator='\n', strip=True) lines = full_text.split('\n') debug_info["total_lines_raw"] = len(lines) article_parts = [] for line in lines: line = line.strip() if not line or len(line) < 2: continue chinese_ratio = extract_chinese_ratio(line) if chinese_ratio > 0.05 or len(line) > 20: clean_line = clean_text(line) if clean_line and not is_noise_text(clean_line): article_parts.append(clean_line) debug_info["lines_kept"] += 1 else: debug_info["lines_filtered"] += 1 else: debug_info["lines_filtered"] += 1 content = '\n'.join(article_parts) return title, content, debug_info debug_info["selector_matched"] = "fallback (body)" # 后备:直接取 body body = soup.find('body') if body: content = body.get_text(separator='\n', strip=True) return title, content[:5000], debug_info # 限制长度 return title, "", debug_info @router.post("/crawl", response_model=CrawlResponse) async def debug_crawl(request: CrawlRequest): """ 实时爬取指定 URL 并返回内容(用于调试) - **url**: 要爬取的新闻 URL - **return_html**: 是否返回原始 HTML(默认 True) """ try: headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" } response = requests.get(request.url, headers=headers, timeout=30) response.encoding = 'utf-8' html = response.text title, content, debug_info = extract_content_from_html(html, request.url) return CrawlResponse( url=request.url, title=title, content=content, content_length=len(content), html_length=len(html), raw_html=html if request.return_html else None, debug_info=debug_info, ) except requests.RequestException as e: raise HTTPException(status_code=500, detail=f"爬取失败: {str(e)}") except Exception as e: logger.error(f"Debug crawl error: {e}") raise HTTPException(status_code=500, detail=f"解析失败: {str(e)}") @router.get("/test-sina") async def test_sina_crawl(): """ 测试新浪财经爬取(使用固定 URL) """ test_url = "https://finance.sina.com.cn/jjxw/2024-12-28/doc-ineayfsz5142013.shtml" request = CrawlRequest(url=test_url, return_html=False) return await debug_crawl(request) ================================================ FILE: backend/app/api/v1/knowledge_graph.py ================================================ """ 知识图谱管理 API 提供图谱的查询、构建、更新、删除接口 """ import logging from typing import List, Dict, Any, Optional from fastapi import APIRouter, HTTPException, BackgroundTasks from pydantic import BaseModel, Field logger = logging.getLogger(__name__) router = APIRouter() # ============ Pydantic 模型 ============ class CompanyGraphResponse(BaseModel): """公司图谱响应""" stock_code: str stock_name: str graph_exists: bool stats: Optional[Dict[str, int]] = None name_variants: List[str] = Field(default_factory=list) businesses: List[Dict[str, Any]] = 
Field(default_factory=list) industries: List[str] = Field(default_factory=list) products: List[str] = Field(default_factory=list) concepts: List[str] = Field(default_factory=list) search_queries: List[str] = Field(default_factory=list, description="生成的检索查询") class BuildGraphRequest(BaseModel): """构建图谱请求""" force_rebuild: bool = Field(default=False, description="是否强制重建") class BuildGraphResponse(BaseModel): """构建图谱响应""" success: bool message: str graph_stats: Optional[Dict[str, int]] = None class UpdateGraphRequest(BaseModel): """更新图谱请求""" update_from_news: bool = Field(default=True, description="是否从新闻更新") news_limit: int = Field(default=20, description="分析的新闻数量") class GraphStatsResponse(BaseModel): """图谱统计响应""" total_companies: int total_nodes: int total_relationships: int companies: List[Dict[str, str]] = Field(default_factory=list) # ============ API 路由 ============ @router.get("/{stock_code}", response_model=CompanyGraphResponse) async def get_company_graph(stock_code: str): """ 获取公司知识图谱 - **stock_code**: 股票代码 """ try: from ...knowledge.graph_service import get_graph_service # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" graph_service = get_graph_service() # 获取图谱 graph = graph_service.get_company_graph(code) if not graph: return CompanyGraphResponse( stock_code=code, stock_name=stock_code, graph_exists=False ) # 获取统计信息 stats = graph_service.get_graph_stats(code) # 获取检索关键词 keyword_set = graph_service.get_search_keywords(code) search_queries = keyword_set.combined_queries if keyword_set else [] return CompanyGraphResponse( stock_code=code, stock_name=graph.company.stock_name, graph_exists=True, stats=stats, name_variants=[v.variant for v in graph.name_variants], businesses=[ { "name": b.business_name, "type": b.business_type, "status": b.status, "description": b.description } for b in graph.businesses ], industries=[i.industry_name for i in graph.industries], products=[p.product_name for p in graph.products], concepts=[c.concept_name for c in graph.concepts], search_queries=search_queries ) except Exception as e: logger.error(f"Failed to get company graph for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/{stock_code}/build", response_model=BuildGraphResponse) async def build_company_graph( stock_code: str, request: BuildGraphRequest, background_tasks: BackgroundTasks ): """ 构建或重建公司知识图谱 - **stock_code**: 股票代码 - **force_rebuild**: 是否强制重建(删除现有图谱) """ try: from ...knowledge.graph_service import get_graph_service from ...knowledge.knowledge_extractor import ( create_knowledge_extractor, AkshareKnowledgeExtractor ) # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" graph_service = get_graph_service() # 检查是否已存在 existing = graph_service.get_company_graph(code) if existing and not request.force_rebuild: return BuildGraphResponse( success=False, message=f"图谱已存在,如需重建请设置 force_rebuild=true", graph_stats=graph_service.get_graph_stats(code) ) # 强制重建:先删除 if existing and request.force_rebuild: graph_service.delete_company_graph(code) logger.info(f"已删除现有图谱: {code}") # 从 akshare 获取信息 akshare_info = AkshareKnowledgeExtractor.extract_company_info(code) if not akshare_info: return BuildGraphResponse( success=False, message=f"无法从 akshare 获取公司信息: {code}" ) # 获取股票名称 stock_name = akshare_info.get('raw_data', {}).get('股票简称', code) # 使用 LLM 提取详细信息 extractor = 
create_knowledge_extractor() # 在后台任务中执行(避免阻塞) import asyncio graph = await extractor.extract_from_akshare(code, stock_name, akshare_info) # 构建图谱 success = graph_service.build_company_graph(graph) if success: stats = graph_service.get_graph_stats(code) return BuildGraphResponse( success=True, message=f"图谱构建成功: {stock_name}", graph_stats=stats ) else: return BuildGraphResponse( success=False, message="图谱构建失败" ) except Exception as e: logger.error(f"Failed to build graph for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/{stock_code}/update", response_model=BuildGraphResponse) async def update_company_graph( stock_code: str, request: UpdateGraphRequest ): """ 更新公司知识图谱 - **stock_code**: 股票代码 - **update_from_news**: 是否从新闻更新 - **news_limit**: 分析的新闻数量 """ try: from ...knowledge.graph_service import get_graph_service from ...knowledge.knowledge_extractor import create_knowledge_extractor from ...core.database import get_db from sqlalchemy.ext.asyncio import AsyncSession from ...models.news import News from sqlalchemy import select, text # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" pure_code = code[2:] if code.startswith(("SH", "SZ")) else code graph_service = get_graph_service() # 检查图谱是否存在 if not graph_service.get_company_graph(code): return BuildGraphResponse( success=False, message="图谱不存在,请先构建图谱" ) if request.update_from_news: # 从数据库获取最新新闻 from ...core.database import get_sync_db_session db = get_sync_db_session() recent_news = db.execute( text(""" SELECT title, content FROM news WHERE stock_codes @> ARRAY[:code]::varchar[] ORDER BY publish_time DESC LIMIT :limit """).bindparams(code=pure_code, limit=request.news_limit) ).fetchall() if not recent_news: return BuildGraphResponse( success=False, message="没有可用的新闻数据" ) news_data = [ {"title": n[0], "content": n[1]} for n in recent_news ] # 提取信息 extractor = create_knowledge_extractor() extracted_info = await extractor.extract_from_news(code, "", news_data) # 更新图谱 if any(extracted_info.values()): success = graph_service.update_from_news(code, "", extracted_info) if success: stats = graph_service.get_graph_stats(code) return BuildGraphResponse( success=True, message=f"图谱已更新: 新增业务{len(extracted_info.get('new_businesses', []))}个, 概念{len(extracted_info.get('new_concepts', []))}个", graph_stats=stats ) else: return BuildGraphResponse( success=False, message="图谱更新失败" ) else: return BuildGraphResponse( success=True, message="未提取到新信息", graph_stats=graph_service.get_graph_stats(code) ) return BuildGraphResponse( success=False, message="未指定更新方式" ) except Exception as e: logger.error(f"Failed to update graph for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.delete("/{stock_code}") async def delete_company_graph(stock_code: str): """ 删除公司知识图谱 - **stock_code**: 股票代码 """ try: from ...knowledge.graph_service import get_graph_service # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" graph_service = get_graph_service() success = graph_service.delete_company_graph(code) if success: return {"success": True, "message": f"图谱已删除: {code}"} else: return {"success": False, "message": "删除失败"} except Exception as e: logger.error(f"Failed to delete graph for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/", response_model=GraphStatsResponse) async def 
get_graph_stats(): """ 获取所有图谱统计信息 """ try: from ...knowledge.graph_service import get_graph_service graph_service = get_graph_service() companies = graph_service.list_all_companies() # 获取总体统计 total_companies = len(companies) # 查询总节点数和关系数(简化版) return GraphStatsResponse( total_companies=total_companies, total_nodes=total_companies * 10, # 估算 total_relationships=total_companies * 15, # 估算 companies=companies ) except Exception as e: logger.error(f"Failed to get graph stats: {e}") raise HTTPException(status_code=500, detail=str(e)) ================================================ FILE: backend/app/api/v1/llm_config.py ================================================ """ LLM 配置 API 路由 返回可用的 LLM 厂商和模型列表 """ import logging from typing import List, Dict, Optional from fastapi import APIRouter from pydantic import BaseModel, Field from ...core.config import settings logger = logging.getLogger(__name__) router = APIRouter() class ModelInfo(BaseModel): """模型信息""" value: str = Field(..., description="模型标识") label: str = Field(..., description="模型显示名称") description: str = Field(default="", description="模型描述") class ProviderInfo(BaseModel): """厂商信息""" value: str = Field(..., description="厂商标识") label: str = Field(..., description="厂商显示名称") icon: str = Field(..., description="厂商图标") models: List[ModelInfo] = Field(..., description="可用模型列表") has_api_key: bool = Field(..., description="是否已配置API Key") class LLMConfigResponse(BaseModel): """LLM 配置响应""" default_provider: str = Field(..., description="默认厂商") default_model: str = Field(..., description="默认模型") providers: List[ProviderInfo] = Field(..., description="可用厂商列表") def parse_models(models_str: str, provider_label: str) -> List[ModelInfo]: """ 解析逗号分隔的模型字符串 Args: models_str: 逗号分隔的模型字符串 provider_label: 厂商显示名称 Returns: 模型信息列表 """ if not models_str: return [] models = [] for model in models_str.split(','): model = model.strip() if model: models.append(ModelInfo( value=model, label=model, description=f"{provider_label} 模型" )) return models @router.get("/config", response_model=LLMConfigResponse) async def get_llm_config(): """ 获取 LLM 配置信息 返回所有可用的厂商和模型列表,以及是否已配置 API Key """ try: providers = [] # 1. 百炼 if settings.BAILIAN_MODELS: providers.append(ProviderInfo( value="bailian", label="百炼(阿里云)", icon="📦", models=parse_models(settings.BAILIAN_MODELS, "百炼"), has_api_key=bool(settings.DASHSCOPE_API_KEY or settings.BAILIAN_API_KEY) )) # 2. OpenAI if settings.OPENAI_MODELS: providers.append(ProviderInfo( value="openai", label="OpenAI", icon="🤖", models=parse_models(settings.OPENAI_MODELS, "OpenAI"), has_api_key=bool(settings.OPENAI_API_KEY) )) # 3. DeepSeek if settings.DEEPSEEK_MODELS: providers.append(ProviderInfo( value="deepseek", label="DeepSeek", icon="🧠", models=parse_models(settings.DEEPSEEK_MODELS, "DeepSeek"), has_api_key=bool(settings.DEEPSEEK_API_KEY) )) # 4. Kimi if settings.MOONSHOT_MODELS: providers.append(ProviderInfo( value="kimi", label="Kimi (Moonshot)", icon="🌙", models=parse_models(settings.MOONSHOT_MODELS, "Kimi"), has_api_key=bool(settings.MOONSHOT_API_KEY) )) # 5. 
智谱 if settings.ZHIPU_MODELS: providers.append(ProviderInfo( value="zhipu", label="智谱", icon="🔮", models=parse_models(settings.ZHIPU_MODELS, "智谱"), has_api_key=bool(settings.ZHIPU_API_KEY) )) return LLMConfigResponse( default_provider=settings.LLM_PROVIDER, default_model=settings.LLM_MODEL, providers=providers ) except Exception as e: logger.error(f"Failed to get LLM config: {e}", exc_info=True) # 返回默认配置 return LLMConfigResponse( default_provider="bailian", default_model="qwen-plus", providers=[] ) ================================================ FILE: backend/app/api/v1/news.py ================================================ """ 新闻管理 API 路由 """ import logging from typing import List, Optional from datetime import datetime, timedelta from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, desc from ...core.database import get_db from ...models.news import News from ...tools import SinaCrawlerTool logger = logging.getLogger(__name__) router = APIRouter() # Pydantic 模型 class NewsResponse(BaseModel): """新闻响应模型""" model_config = {"from_attributes": True} id: int title: str content: str url: str source: str publish_time: Optional[str] = None stock_codes: Optional[List[str]] = None sentiment_score: Optional[float] = None created_at: str class CrawlRequest(BaseModel): """爬取请求模型""" source: str = Field(default="sina", description="新闻源(sina, jrj, cnstock)") start_page: int = Field(default=1, ge=1, description="起始页码") end_page: int = Field(default=1, ge=1, le=10, description="结束页码") class CrawlResponse(BaseModel): """爬取响应模型""" success: bool message: str crawled_count: int saved_count: int source: str class BatchDeleteRequest(BaseModel): """批量删除请求模型""" news_ids: List[int] = Field(..., description="要删除的新闻ID列表") class BatchDeleteResponse(BaseModel): """批量删除响应模型""" success: bool message: str deleted_count: int # 后台任务:爬取并保存新闻(使用同步方式) def crawl_and_save_news_sync( source: str, start_page: int, end_page: int ): """ 后台任务:爬取新闻并保存到数据库(同步版本) """ from sqlalchemy import create_engine from sqlalchemy.orm import Session from ...core.config import settings try: logger.info(f"Starting crawl task: {source}, pages {start_page}-{end_page}") # 创建爬虫 if source == "sina": crawler = SinaCrawlerTool() else: logger.error(f"Unsupported source: {source}") return # 执行爬取 news_list = crawler.crawl(start_page, end_page) logger.info(f"Crawled {len(news_list)} news items") # 创建新的数据库连接(同步) engine = create_engine(settings.SYNC_DATABASE_URL) db = Session(engine) try: # 时间过滤:只保存最近7天内的新闻(避免保存太旧的新闻) cutoff_time = datetime.utcnow() - timedelta(days=7) # 保存到数据库 saved_count = 0 skipped_old_count = 0 skipped_existing_count = 0 for news_item in news_list: # 时间过滤:跳过太旧的新闻 if news_item.publish_time and news_item.publish_time < cutoff_time: skipped_old_count += 1 logger.debug(f"Skipping old news: {news_item.title[:50]} (published: {news_item.publish_time})") continue # 检查URL是否已存在 existing = db.execute( select(News).where(News.url == news_item.url) ).scalar_one_or_none() if existing: skipped_existing_count += 1 logger.debug(f"News already exists: {news_item.url}") continue # 创建新记录 news = News( title=news_item.title, content=news_item.content, url=news_item.url, source=news_item.source, publish_time=news_item.publish_time, author=news_item.author, keywords=news_item.keywords, stock_codes=news_item.stock_codes, # summary 字段已移除,content 包含完整内容 ) db.add(news) saved_count += 1 logger.info(f"Saved new news: 
{news_item.title[:50]} (published: {news_item.publish_time})") db.commit() logger.info( f"Crawl summary: crawled={len(news_list)}, " f"saved={saved_count}, " f"skipped_old={skipped_old_count}, " f"skipped_existing={skipped_existing_count}" ) finally: db.close() except Exception as e: logger.error(f"Crawl task failed: {e}", exc_info=True) # API 端点 @router.post("/crawl", response_model=CrawlResponse) async def crawl_news( request: CrawlRequest, background_tasks: BackgroundTasks ): """ 触发新闻爬取任务(异步后台任务) - **source**: 新闻源(sina, jrj, cnstock) - **start_page**: 起始页码 - **end_page**: 结束页码 注意:这是简单的后台任务版本。如需更强大的任务管理, 请使用 POST /api/v1/tasks/cold-start 触发 Celery 任务。 """ # 添加到后台任务(同步版本) background_tasks.add_task( crawl_and_save_news_sync, request.source, request.start_page, request.end_page ) logger.info(f"Background crawl task added: {request.source}, pages {request.start_page}-{request.end_page}") return CrawlResponse( success=True, message=f"Crawl task started for {request.source}, pages {request.start_page}-{request.end_page}", crawled_count=0, # 后台任务还未完成 saved_count=0, source=request.source ) @router.post("/refresh", response_model=CrawlResponse) async def refresh_news( source: str = Query("sina", description="新闻源"), pages: int = Query(1, ge=1, le=5, description="爬取页数"), background_tasks: BackgroundTasks = None ): """ 刷新新闻(前端刷新按钮调用) - **source**: 新闻源(sina, tencent, nbd, eastmoney, yicai, 163) - **pages**: 爬取页数(1-5) """ background_tasks.add_task( crawl_and_save_news_sync, source, 1, # start_page pages # end_page ) logger.info(f"Refresh task started: {source}, {pages} pages") return CrawlResponse( success=True, message=f"刷新任务已启动:{source},{pages} 页", crawled_count=0, saved_count=0, source=source ) @router.get("/", response_model=List[NewsResponse]) async def get_news_list( skip: int = Query(0, ge=0, description="跳过的记录数"), limit: int = Query(20, ge=1, le=100, description="返回的记录数"), source: Optional[str] = Query(None, description="按来源筛选"), db: AsyncSession = Depends(get_db) ): """ 获取新闻列表 - **skip**: 跳过的记录数(分页) - **limit**: 返回的记录数 - **source**: 按来源筛选(可选) """ try: query = select(News).order_by(desc(News.created_at)) if source: query = query.where(News.source == source) query = query.offset(skip).limit(limit) result = await db.execute(query) news_list = result.scalars().all() return [NewsResponse(**news.to_dict()) for news in news_list] except Exception as e: logger.error(f"Failed to get news list: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/latest", response_model=List[NewsResponse]) async def get_latest_news( limit: int = Query(20, ge=1, le=500, description="返回的记录数"), source: Optional[str] = Query(None, description="按来源筛选"), db: AsyncSession = Depends(get_db) ): """ 获取最新新闻(按发布时间排序) - **limit**: 返回的记录数(最多500条) - **source**: 按来源筛选(可选) """ try: query = select(News).order_by(desc(News.publish_time)) if source: query = query.where(News.source == source) query = query.limit(limit) result = await db.execute(query) news_list = result.scalars().all() return [NewsResponse(**news.to_dict()) for news in news_list] except Exception as e: logger.error(f"Failed to get latest news: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{news_id}", response_model=NewsResponse) async def get_news_detail( news_id: int, db: AsyncSession = Depends(get_db) ): """ 获取新闻详情 - **news_id**: 新闻ID """ try: result = await db.execute( select(News).where(News.id == news_id) ) news = result.scalar_one_or_none() if not news: raise HTTPException(status_code=404, detail="News not found") return 
NewsResponse(**news.to_dict()) except HTTPException: raise except Exception as e: logger.error(f"Failed to get news {news_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/batch/delete", response_model=BatchDeleteResponse) async def batch_delete_news( request: BatchDeleteRequest, db: AsyncSession = Depends(get_db) ): """ 批量删除新闻 - **news_ids**: 要删除的新闻ID列表 """ try: if not request.news_ids: raise HTTPException(status_code=400, detail="news_ids cannot be empty") # 查询要删除的新闻 result = await db.execute( select(News).where(News.id.in_(request.news_ids)) ) news_list = result.scalars().all() deleted_count = len(news_list) if deleted_count == 0: return BatchDeleteResponse( success=True, message="No news found to delete", deleted_count=0 ) # 批量删除 for news in news_list: await db.delete(news) await db.commit() logger.info(f"Batch deleted {deleted_count} news items: {request.news_ids}") return BatchDeleteResponse( success=True, message=f"Successfully deleted {deleted_count} news items", deleted_count=deleted_count ) except HTTPException: raise except Exception as e: logger.error(f"Failed to batch delete news: {e}") await db.rollback() raise HTTPException(status_code=500, detail=str(e)) @router.delete("/{news_id}") async def delete_news( news_id: int, db: AsyncSession = Depends(get_db) ): """ 删除新闻 - **news_id**: 新闻ID """ try: result = await db.execute( select(News).where(News.id == news_id) ) news = result.scalar_one_or_none() if not news: raise HTTPException(status_code=404, detail="News not found") await db.delete(news) await db.commit() return {"success": True, "message": f"News {news_id} deleted"} except HTTPException: raise except Exception as e: logger.error(f"Failed to delete news {news_id}: {e}") await db.rollback() raise HTTPException(status_code=500, detail=str(e)) ================================================ FILE: backend/app/api/v1/news_v2.py ================================================ """ 新闻 API v2 - 使用新的 Financial Data Layer 新功能: 1. 多数据源支持:可指定 provider (sina, tencent, nbd...) 2. 自动降级:一个源失败自动切换另一个 3. 标准化数据:统一的 NewsData 格式 4. 
实时获取:直接从数据源获取,不经过数据库 前端可通过对比 /api/v1/news (旧) vs /api/v1/news/v2 (新) 看到差异 """ import logging from typing import List, Optional from datetime import datetime from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel, Field from ...financial import get_registry, NewsQueryParams from ...financial.tools import FinancialNewsTool, setup_default_providers logger = logging.getLogger(__name__) router = APIRouter() # 确保 Provider 已注册 setup_default_providers() class NewsDataResponse(BaseModel): """标准化新闻响应(使用 NewsData 模型)""" id: str title: str content: str summary: Optional[str] = None source: str source_url: str publish_time: datetime stock_codes: List[str] = [] sentiment: Optional[str] = None sentiment_score: Optional[float] = None class FetchNewsResponse(BaseModel): """获取新闻响应""" success: bool count: int provider: Optional[str] = None available_providers: Optional[List[str]] = None data: List[NewsDataResponse] = [] error: Optional[str] = None class ProviderInfoResponse(BaseModel): """Provider 信息响应""" name: str display_name: str description: str supported_types: List[str] priority: int @router.get("/fetch", response_model=FetchNewsResponse) async def fetch_news_realtime( stock_codes: Optional[str] = Query( None, description="股票代码,多个用逗号分隔,如 '600519,000001'" ), keywords: Optional[str] = Query( None, description="关键词,多个用逗号分隔" ), limit: int = Query( 20, ge=1, le=100, description="返回条数" ), provider: Optional[str] = Query( None, description="指定数据源(sina, tencent, nbd),不指定则自动选择" ) ): """ 实时获取新闻(使用新的 Provider-Fetcher 架构) 特点: - 直接从数据源获取,不经过数据库 - 支持指定数据源或自动选择 - 返回标准化的 NewsData 格式 示例: - GET /api/v1/news/v2/fetch?stock_codes=600519&limit=10 - GET /api/v1/news/v2/fetch?keywords=茅台,白酒&provider=sina """ tool = FinancialNewsTool() # 解析参数 stock_code_list = stock_codes.split(",") if stock_codes else None keyword_list = keywords.split(",") if keywords else None try: result = await tool.aexecute( stock_codes=stock_code_list, keywords=keyword_list, limit=limit, provider=provider ) if result["success"]: # 转换为响应格式 news_list = [ NewsDataResponse( id=item["id"], title=item["title"], content=item["content"], summary=item.get("summary"), source=item["source"], source_url=item["source_url"], publish_time=item["publish_time"], stock_codes=item.get("stock_codes", []), sentiment=item.get("sentiment"), sentiment_score=item.get("sentiment_score") ) for item in result["data"] ] return FetchNewsResponse( success=True, count=result["count"], provider=result.get("provider"), data=news_list ) else: return FetchNewsResponse( success=False, count=0, error=result.get("error"), available_providers=result.get("available_providers", []) ) except Exception as e: logger.exception(f"Failed to fetch news: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/providers", response_model=List[ProviderInfoResponse]) async def list_providers(): """ 列出所有可用的数据源 Provider 返回: - 每个 Provider 的名称、描述、支持的数据类型、优先级 """ registry = get_registry() providers = [] for name in registry.list_providers(): provider = registry.get_provider(name) if provider: providers.append(ProviderInfoResponse( name=provider.info.name, display_name=provider.info.display_name, description=provider.info.description, supported_types=list(provider.fetchers.keys()), priority=provider.info.priority )) return providers @router.get("/providers/{provider_name}/test") async def test_provider( provider_name: str, limit: int = Query(5, ge=1, le=20) ): """ 测试指定的 Provider 是否工作正常 返回: - 测试结果和获取到的样本数据 """ tool = FinancialNewsTool() try: result = await 
tool.aexecute( limit=limit, provider=provider_name ) return { "provider": provider_name, "success": result["success"], "count": result.get("count", 0), "error": result.get("error"), "sample_titles": [ item["title"][:50] for item in result.get("data", [])[:3] ] } except Exception as e: return { "provider": provider_name, "success": False, "error": str(e) } ================================================ FILE: backend/app/api/v1/stocks.py ================================================ """ 股票分析 API 路由 - Phase 2 提供个股分析、关联新闻、情感趋势等接口 支持 akshare 真实股票数据 """ import logging from datetime import datetime, timedelta from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func, and_, desc, text, or_ from sqlalchemy.dialects.postgresql import ARRAY, array from ...core.database import get_db from ...models.news import News from ...models.stock import Stock from ...models.analysis import Analysis from ...models.crawl_task import CrawlTask, CrawlMode, TaskStatus from ...services.stock_data_service import stock_data_service from ...tasks.crawl_tasks import targeted_stock_crawl_task logger = logging.getLogger(__name__) router = APIRouter() # ============ Pydantic 模型 ============ class StockInfo(BaseModel): """股票信息""" model_config = {"from_attributes": True} code: str name: str full_code: Optional[str] = None industry: Optional[str] = None market: Optional[str] = None pe_ratio: Optional[float] = None market_cap: Optional[float] = None class StockNewsItem(BaseModel): """股票关联新闻""" id: int title: str content: str url: str source: str publish_time: Optional[str] = None sentiment_score: Optional[float] = None has_analysis: bool = False class SentimentTrendPoint(BaseModel): """情感趋势数据点""" date: str avg_sentiment: float news_count: int positive_count: int negative_count: int neutral_count: int class StockOverview(BaseModel): """股票概览数据""" code: str name: Optional[str] = None total_news: int analyzed_news: int avg_sentiment: Optional[float] = None recent_sentiment: Optional[float] = None # 最近7天 sentiment_trend: str # "up", "down", "stable" last_news_time: Optional[str] = None class KLineDataPoint(BaseModel): """K线数据点(akshare 真实数据)""" timestamp: int # 时间戳(毫秒) date: str open: float high: float low: float close: float volume: int turnover: Optional[float] = None # 成交额 change_percent: Optional[float] = None # 涨跌幅 change_amount: Optional[float] = None # 涨跌额 amplitude: Optional[float] = None # 振幅 turnover_rate: Optional[float] = None # 换手率 # ============ API 端点 ============ # ⚠️ 注意:具体路径的路由必须放在动态路由 /{stock_code} 之前! 
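# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the repository): FastAPI/Starlette
# matches routes in registration order, so if the dynamic /{stock_code} route
# were registered before the literal paths, a request to /count or /init would
# be captured by the dynamic handler (stock_code="count"). The _demo_* names
# below are hypothetical and exist only for this standalone example.
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient

_demo_router = APIRouter()

@_demo_router.get("/{stock_code}")  # dynamic route registered FIRST (wrong order)
async def _demo_overview(stock_code: str):
    return {"matched": "dynamic", "stock_code": stock_code}

@_demo_router.get("/count")  # literal route registered AFTER the dynamic one
async def _demo_count():
    return {"matched": "count"}

if __name__ == "__main__":
    _demo_app = FastAPI()
    _demo_app.include_router(_demo_router, prefix="/api/v1/stocks")
    _demo_client = TestClient(_demo_app)
    # Prints {'matched': 'dynamic', 'stock_code': 'count'}: the literal /count
    # route is shadowed, which is why this module registers literal paths first.
    print(_demo_client.get("/api/v1/stocks/count").json())
# --------------------------- end of sketch ---------------------------------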
class StockSearchResult(BaseModel): """股票搜索结果""" code: str name: str full_code: str market: Optional[str] = None industry: Optional[str] = None @router.get("/search/realtime", response_model=List[StockSearchResult]) async def search_stocks_realtime( q: str = Query(..., min_length=1, description="搜索关键词(代码或名称)"), limit: int = Query(20, le=50), db: AsyncSession = Depends(get_db) ): """ 搜索股票(从数据库,支持代码和名称模糊匹配) - **q**: 搜索关键词(如 "600519" 或 "茅台") - **limit**: 返回数量限制 """ try: # 从数据库搜索 query = select(Stock).where( (Stock.code.ilike(f"%{q}%")) | (Stock.name.ilike(f"%{q}%")) | (Stock.full_code.ilike(f"%{q}%")) ).limit(limit) result = await db.execute(query) stocks = result.scalars().all() if stocks: return [ StockSearchResult( code=stock.code, name=stock.name, full_code=stock.full_code or f"{'SH' if stock.code.startswith('6') else 'SZ'}{stock.code}", market=stock.market, industry=stock.industry, ) for stock in stocks ] return [] except Exception as e: logger.error(f"Failed to search stocks: {e}") raise HTTPException(status_code=500, detail=str(e)) class StockInitResponse(BaseModel): """股票数据初始化响应""" success: bool message: str count: int = 0 @router.post("/init", response_model=StockInitResponse) async def init_stock_data( db: AsyncSession = Depends(get_db) ): """ 初始化股票数据(从 akshare 获取全部 A 股并存入数据库) """ try: import akshare as ak from datetime import datetime from sqlalchemy import delete logger.info("Starting stock data initialization...") df = ak.stock_zh_a_spot_em() if df is None or df.empty: return StockInitResponse(success=False, message="Failed to fetch stocks from akshare", count=0) await db.execute(delete(Stock)) count = 0 for _, row in df.iterrows(): code = str(row['代码']) name = str(row['名称']) if not code or not name or name in ['N/A', 'nan', '']: continue if code.startswith('6'): market = "SH" full_code = f"SH{code}" elif code.startswith('0') or code.startswith('3'): market = "SZ" full_code = f"SZ{code}" else: market = "OTHER" full_code = code stock = Stock( code=code, name=name, full_code=full_code, market=market, status="active", created_at=datetime.utcnow(), updated_at=datetime.utcnow(), ) db.add(stock) count += 1 await db.commit() return StockInitResponse(success=True, message=f"Successfully initialized {count} stocks", count=count) except ImportError: return StockInitResponse(success=False, message="akshare not installed", count=0) except Exception as e: logger.error(f"Failed to init stocks: {e}") await db.rollback() raise HTTPException(status_code=500, detail=str(e)) @router.get("/count") async def get_stock_count(db: AsyncSession = Depends(get_db)): """获取数据库中的股票数量""" from sqlalchemy import func as sql_func result = await db.execute(select(sql_func.count(Stock.id))) count = result.scalar() or 0 return {"count": count, "message": f"Database has {count} stocks"} # ============ 动态路由(必须放在最后) ============ @router.get("/{stock_code}", response_model=StockOverview) async def get_stock_overview( stock_code: str, db: AsyncSession = Depends(get_db) ): """ 获取股票概览信息 - **stock_code**: 股票代码(如 SH600519, 600519) """ # 标准化股票代码(支持带前缀和不带前缀) code = stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): short_code = code[2:] else: short_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" try: # 查询股票基本信息 stock_query = select(Stock).where( (Stock.code == short_code) | (Stock.full_code == code) ) result = await db.execute(stock_query) stock = result.scalar_one_or_none() stock_name = stock.name if stock else None # 统计关联新闻 # 使用 PostgreSQL 原生 ARRAY 查询语法 stock_codes_filter = text( 
"stock_codes @> ARRAY[:code1]::varchar[] OR stock_codes @> ARRAY[:code2]::varchar[]" ).bindparams(code1=short_code, code2=code) news_query = select(func.count(News.id)).where(stock_codes_filter) result = await db.execute(news_query) total_news = result.scalar() or 0 # 已分析的新闻数量 analyzed_query = select(func.count(News.id)).where( and_( stock_codes_filter, News.sentiment_score.isnot(None) ) ) result = await db.execute(analyzed_query) analyzed_news = result.scalar() or 0 # 计算平均情感 avg_sentiment_query = select(func.avg(News.sentiment_score)).where( and_( stock_codes_filter, News.sentiment_score.isnot(None) ) ) result = await db.execute(avg_sentiment_query) avg_sentiment = result.scalar() # 最近7天的平均情感 seven_days_ago = datetime.utcnow() - timedelta(days=7) recent_query = select(func.avg(News.sentiment_score)).where( and_( stock_codes_filter, News.sentiment_score.isnot(None), News.publish_time >= seven_days_ago ) ) result = await db.execute(recent_query) recent_sentiment = result.scalar() # 判断趋势 if avg_sentiment is not None and recent_sentiment is not None: diff = recent_sentiment - avg_sentiment if diff > 0.1: sentiment_trend = "up" elif diff < -0.1: sentiment_trend = "down" else: sentiment_trend = "stable" else: sentiment_trend = "stable" # 最新新闻时间 last_news_query = select(News.publish_time).where( stock_codes_filter ).order_by(desc(News.publish_time)).limit(1) result = await db.execute(last_news_query) last_news_time = result.scalar() return StockOverview( code=code, name=stock_name, total_news=total_news, analyzed_news=analyzed_news, avg_sentiment=round(avg_sentiment, 3) if avg_sentiment else None, recent_sentiment=round(recent_sentiment, 3) if recent_sentiment else None, sentiment_trend=sentiment_trend, last_news_time=last_news_time.isoformat() if last_news_time else None ) except Exception as e: logger.error(f"Failed to get stock overview for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{stock_code}/news", response_model=List[StockNewsItem]) async def get_stock_news( stock_code: str, limit: int = Query(50, le=200), offset: int = Query(0, ge=0), sentiment: Optional[str] = Query(None, description="筛选情感: positive, negative, neutral"), db: AsyncSession = Depends(get_db) ): """ 获取股票关联新闻列表 - **stock_code**: 股票代码 - **limit**: 返回数量限制 - **offset**: 偏移量 - **sentiment**: 情感筛选 """ # 标准化股票代码 code = stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): short_code = code[2:] else: short_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" try: # 构建查询 - 使用 PostgreSQL 原生 ARRAY 查询语法 stock_codes_filter = text( "stock_codes @> ARRAY[:code1]::varchar[] OR stock_codes @> ARRAY[:code2]::varchar[]" ).bindparams(code1=short_code, code2=code) query = select(News).where(stock_codes_filter) # 情感筛选 if sentiment: if sentiment == "positive": query = query.where(News.sentiment_score > 0.1) elif sentiment == "negative": query = query.where(News.sentiment_score < -0.1) elif sentiment == "neutral": query = query.where( and_( News.sentiment_score >= -0.1, News.sentiment_score <= 0.1 ) ) # 排序和分页 query = query.order_by(desc(News.publish_time)).offset(offset).limit(limit) result = await db.execute(query) news_list = result.scalars().all() # 检查每条新闻是否有分析 response = [] for news in news_list: # 检查是否有分析记录 analysis_query = select(func.count(Analysis.id)).where(Analysis.news_id == news.id) analysis_result = await db.execute(analysis_query) has_analysis = (analysis_result.scalar() or 0) > 0 response.append(StockNewsItem( id=news.id, title=news.title, 
content=news.content[:500] + "..." if len(news.content) > 500 else news.content, url=news.url, source=news.source, publish_time=news.publish_time.isoformat() if news.publish_time else None, sentiment_score=news.sentiment_score, has_analysis=has_analysis )) return response except Exception as e: logger.error(f"Failed to get news for stock {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.delete("/{stock_code}/news") async def delete_stock_news( stock_code: str, db: AsyncSession = Depends(get_db) ): """ 清除股票的所有关联新闻 - **stock_code**: 股票代码 """ # 标准化股票代码 code = stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): short_code = code[2:] else: short_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" try: # 构建查询 - 使用 PostgreSQL 原生 ARRAY 查询语法 stock_codes_filter = text( "stock_codes @> ARRAY[:code1]::varchar[] OR stock_codes @> ARRAY[:code2]::varchar[]" ).bindparams(code1=short_code, code2=code) # 先查询要删除的新闻ID列表(用于同时删除关联的分析记录) news_query = select(News.id).where(stock_codes_filter) news_result = await db.execute(news_query) news_ids = [row[0] for row in news_result.all()] deleted_count = len(news_ids) if deleted_count > 0: # 删除关联的分析记录 analysis_delete = await db.execute( text("DELETE FROM analyses WHERE news_id = ANY(:news_ids)").bindparams(news_ids=news_ids) ) logger.info(f"Deleted {analysis_delete.rowcount} analysis records for stock {stock_code}") # 删除新闻记录 news_delete = await db.execute( text("DELETE FROM news WHERE id = ANY(:news_ids)").bindparams(news_ids=news_ids) ) await db.commit() logger.info(f"Deleted {deleted_count} news for stock {stock_code}") return { "success": True, "message": f"已清除 {deleted_count} 条新闻", "deleted_count": deleted_count } except Exception as e: await db.rollback() logger.error(f"Failed to delete news for stock {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{stock_code}/sentiment-trend", response_model=List[SentimentTrendPoint]) async def get_sentiment_trend( stock_code: str, days: int = Query(30, le=90, ge=7, description="天数范围"), db: AsyncSession = Depends(get_db) ): """ 获取股票情感趋势(按天聚合) - **stock_code**: 股票代码 - **days**: 查询天数范围(7-90天) """ # 标准化股票代码 code = stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): short_code = code[2:] else: short_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" try: start_date = datetime.utcnow() - timedelta(days=days) # 按天聚合情感数据 # 使用原生 SQL 进行日期聚合 from sqlalchemy import text query = text(""" SELECT DATE(publish_time) as date, AVG(sentiment_score) as avg_sentiment, COUNT(*) as news_count, SUM(CASE WHEN sentiment_score > 0.1 THEN 1 ELSE 0 END) as positive_count, SUM(CASE WHEN sentiment_score < -0.1 THEN 1 ELSE 0 END) as negative_count, SUM(CASE WHEN sentiment_score >= -0.1 AND sentiment_score <= 0.1 THEN 1 ELSE 0 END) as neutral_count FROM news WHERE ( :short_code = ANY(stock_codes) OR :full_code = ANY(stock_codes) ) AND publish_time >= :start_date AND sentiment_score IS NOT NULL GROUP BY DATE(publish_time) ORDER BY date ASC """) result = await db.execute(query, { "short_code": short_code, "full_code": code, "start_date": start_date }) rows = result.fetchall() trend_data = [] for row in rows: trend_data.append(SentimentTrendPoint( date=row.date.isoformat() if row.date else "", avg_sentiment=round(row.avg_sentiment, 3) if row.avg_sentiment else 0, news_count=row.news_count or 0, positive_count=row.positive_count or 0, negative_count=row.negative_count or 0, neutral_count=row.neutral_count or 0 )) return 
trend_data except Exception as e: logger.error(f"Failed to get sentiment trend for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{stock_code}/kline", response_model=List[KLineDataPoint]) async def get_kline_data( stock_code: str, period: str = Query("daily", description="周期: daily, 1m, 5m, 15m, 30m, 60m"), limit: int = Query(90, le=500, ge=10, description="数据条数"), adjust: str = Query("qfq", description="复权类型: qfq=前复权, hfq=后复权, 空=不复权(仅日线有效)"), db: AsyncSession = Depends(get_db) ): """ 获取K线数据(真实数据,使用 akshare) - **stock_code**: 股票代码(支持 600519, SH600519, sh600519 等格式) - **period**: 周期类型 - daily: 日线(默认) - 1m: 1分钟 - 5m: 5分钟 - 15m: 15分钟 - 30m: 30分钟 - 60m: 60分钟/1小时 - **limit**: 返回数据条数(10-500,默认90) - **adjust**: 复权类型 (qfq=前复权, hfq=后复权, ""=不复权),仅对日线有效 """ try: kline_data = await stock_data_service.get_kline_data( stock_code=stock_code, period=period, limit=limit, adjust=adjust ) if not kline_data: logger.warning(f"No kline data for {stock_code} period={period}") return [] return [KLineDataPoint(**item) for item in kline_data] except Exception as e: logger.error(f"Failed to get kline data for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) class RealtimeQuote(BaseModel): """实时行情""" code: str name: str price: float change_percent: float change_amount: float volume: int turnover: float high: float low: float open: float prev_close: float @router.get("/{stock_code}/realtime", response_model=Optional[RealtimeQuote]) async def get_realtime_quote( stock_code: str, db: AsyncSession = Depends(get_db) ): """ 获取实时行情(使用 akshare) - **stock_code**: 股票代码 """ try: quote = await stock_data_service.get_realtime_quote(stock_code) if quote: return RealtimeQuote(**quote) return None except Exception as e: logger.error(f"Failed to get realtime quote for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/search/code", response_model=List[StockInfo]) async def search_stocks_db( q: str = Query(..., min_length=1, description="搜索关键词"), limit: int = Query(10, le=50), db: AsyncSession = Depends(get_db) ): """ 从数据库搜索股票 - **q**: 搜索关键词(代码或名称) """ try: query = select(Stock).where( (Stock.code.ilike(f"%{q}%")) | (Stock.name.ilike(f"%{q}%")) | (Stock.full_code.ilike(f"%{q}%")) ).limit(limit) result = await db.execute(query) stocks = result.scalars().all() return [StockInfo.model_validate(stock) for stock in stocks] except Exception as e: logger.error(f"Failed to search stocks: {e}") raise HTTPException(status_code=500, detail=str(e)) # ============ 定向爬取 API ============ class TargetedCrawlRequest(BaseModel): """定向爬取请求""" stock_name: str = Field(..., description="股票名称") days: int = Field(default=30, ge=1, le=90, description="搜索时间范围(天)") class TargetedCrawlResponse(BaseModel): """定向爬取响应""" success: bool message: str task_id: Optional[int] = None celery_task_id: Optional[str] = None class TargetedCrawlStatus(BaseModel): """定向爬取状态""" task_id: Optional[int] = None status: str # idle, pending, running, completed, failed celery_task_id: Optional[str] = None progress: Optional[dict] = None crawled_count: Optional[int] = None saved_count: Optional[int] = None error_message: Optional[str] = None execution_time: Optional[float] = None started_at: Optional[str] = None completed_at: Optional[str] = None @router.post("/{stock_code}/targeted-crawl", response_model=TargetedCrawlResponse) async def start_targeted_crawl( stock_code: str, request: TargetedCrawlRequest, db: AsyncSession = Depends(get_db) ): """ 触发定向爬取任务 - **stock_code**: 股票代码(如 SH600519) - 
**stock_name**: 股票名称(如 贵州茅台) - **days**: 搜索时间范围(默认30天) """ try: # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" # 检查是否有正在运行的任务 running_task = await db.execute( select(CrawlTask).where( and_( CrawlTask.mode == CrawlMode.TARGETED, CrawlTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]), text("config->>'stock_code' = :stock_code").bindparams(stock_code=code) ) ).order_by(desc(CrawlTask.created_at)).limit(1) ) existing_task = running_task.scalar_one_or_none() if existing_task: return TargetedCrawlResponse( success=False, message=f"该股票已有正在进行的爬取任务 (ID: {existing_task.id})", task_id=existing_task.id, celery_task_id=existing_task.celery_task_id ) logger.info(f"触发定向爬取任务: {request.stock_name}({code}), 时间范围: {request.days}天") # 先在数据库中创建任务记录(PENDING状态),这样前端轮询时能立即看到 task_record = CrawlTask( mode=CrawlMode.TARGETED, status=TaskStatus.PENDING, source="targeted", config={ "stock_code": code, "stock_name": request.stock_name, "days": request.days, }, ) db.add(task_record) await db.commit() await db.refresh(task_record) # 触发 Celery 任务,传入任务记录ID celery_task = targeted_stock_crawl_task.apply_async( args=(code, request.stock_name, request.days, task_record.id) ) # 更新 celery_task_id task_record.celery_task_id = celery_task.id await db.commit() return TargetedCrawlResponse( success=True, message=f"定向爬取任务已启动: {request.stock_name}({code})", task_id=task_record.id, celery_task_id=celery_task.id ) except Exception as e: logger.error(f"Failed to start targeted crawl for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{stock_code}/targeted-crawl/status", response_model=TargetedCrawlStatus) async def get_targeted_crawl_status( stock_code: str, db: AsyncSession = Depends(get_db) ): """ 查询定向爬取任务状态 - **stock_code**: 股票代码 """ try: # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" # 查询最近的定向爬取任务 task_query = select(CrawlTask).where( and_( CrawlTask.mode == CrawlMode.TARGETED, text("config->>'stock_code' = :stock_code").bindparams(stock_code=code) ) ).order_by(desc(CrawlTask.created_at)).limit(1) result = await db.execute(task_query) task = result.scalar_one_or_none() if not task: return TargetedCrawlStatus( status="idle", progress=None ) # 检测超时:如果任务在 PENDING 状态超过 5 分钟,自动标记为失败 if task.status == TaskStatus.PENDING and task.created_at: pending_duration = datetime.utcnow() - task.created_at if pending_duration > timedelta(minutes=5): logger.warning(f"Task {task.id} has been PENDING for {pending_duration}, marking as FAILED (timeout)") task.status = TaskStatus.FAILED task.completed_at = datetime.utcnow() task.error_message = "任务超时:Celery worker 可能未启动或已停止" await db.commit() # 检测运行超时:如果任务在 RUNNING 状态超过 30 分钟,也标记为失败 if task.status == TaskStatus.RUNNING and task.started_at: running_duration = datetime.utcnow() - task.started_at if running_duration > timedelta(minutes=30): logger.warning(f"Task {task.id} has been RUNNING for {running_duration}, marking as FAILED (timeout)") task.status = TaskStatus.FAILED task.completed_at = datetime.utcnow() task.error_message = "任务执行超时" await db.commit() return TargetedCrawlStatus( task_id=task.id, status=task.status, celery_task_id=task.celery_task_id, progress=task.progress, crawled_count=task.crawled_count, saved_count=task.saved_count, error_message=task.error_message, execution_time=task.execution_time, started_at=task.started_at.isoformat() if 
task.started_at else None, completed_at=task.completed_at.isoformat() if task.completed_at else None ) except Exception as e: logger.error(f"Failed to get targeted crawl status for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/{stock_code}/targeted-crawl/cancel") async def cancel_targeted_crawl( stock_code: str, db: AsyncSession = Depends(get_db) ): """ 取消定向爬取任务 - **stock_code**: 股票代码 """ try: # 标准化股票代码 code = stock_code.upper() if not (code.startswith("SH") or code.startswith("SZ")): code = f"SH{code}" if code.startswith("6") else f"SZ{code}" # 查找正在进行的任务 task_query = select(CrawlTask).where( and_( CrawlTask.mode == CrawlMode.TARGETED, CrawlTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]), text("config->>'stock_code' = :stock_code").bindparams(stock_code=code) ) ).order_by(desc(CrawlTask.created_at)).limit(1) result = await db.execute(task_query) task = result.scalar_one_or_none() if not task: return { "success": True, "message": "没有正在进行的任务" } # 更新任务状态为已取消 task.status = TaskStatus.CANCELLED task.completed_at = datetime.utcnow() task.error_message = "用户手动取消" await db.commit() # 如果有 celery_task_id,尝试撤销 Celery 任务 if task.celery_task_id: try: from ...tasks.crawl_tasks import celery_app celery_app.control.revoke(task.celery_task_id, terminate=True) logger.info(f"Revoked Celery task: {task.celery_task_id}") except Exception as e: logger.warning(f"Failed to revoke Celery task: {e}") logger.info(f"Cancelled targeted crawl task {task.id} for {code}") return { "success": True, "message": f"已取消任务 (ID: {task.id})", "task_id": task.id } except Exception as e: logger.error(f"Failed to cancel targeted crawl for {stock_code}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/cache/clear") async def clear_stock_data_cache( pattern: Optional[str] = Query(None, description="缓存键模式,如 'kline' 或 '002837'") ): """ 清除股票数据缓存 - **pattern**: 可选的缓存键模式,如果不提供则清除所有缓存 Examples: - `POST /api/v1/stocks/cache/clear` - 清除所有缓存 - `POST /api/v1/stocks/cache/clear?pattern=kline` - 只清除K线缓存 - `POST /api/v1/stocks/cache/clear?pattern=002837` - 只清除特定股票的缓存 """ try: stock_data_service.clear_cache(pattern) return { "success": True, "message": f"Cache cleared successfully" + (f" (pattern: {pattern})" if pattern else " (all)") } except Exception as e: logger.error(f"Failed to clear cache: {e}") raise HTTPException(status_code=500, detail=str(e)) ================================================ FILE: backend/app/api/v1/tasks.py ================================================ """ 任务管理 API 路由 """ import logging from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, desc from datetime import datetime from ...core.database import get_db from ...models.crawl_task import CrawlTask, CrawlMode, TaskStatus from ...tasks.crawl_tasks import cold_start_crawl_task, realtime_crawl_task logger = logging.getLogger(__name__) router = APIRouter() # Pydantic 模型 class TaskResponse(BaseModel): """任务响应模型""" model_config = {"from_attributes": True} id: int celery_task_id: Optional[str] = None mode: str status: str source: str config: Optional[dict] = None progress: Optional[dict] = None current_page: Optional[int] = None total_pages: Optional[int] = None result: Optional[dict] = None crawled_count: int saved_count: int error_message: Optional[str] = None execution_time: Optional[float] = None created_at: str started_at: Optional[str] = 
None completed_at: Optional[str] = None class ColdStartRequest(BaseModel): """冷启动请求模型""" source: str = Field(default="sina", description="新闻源") start_page: int = Field(default=1, ge=1, description="起始页码") end_page: int = Field(default=50, ge=1, le=100, description="结束页码") class ColdStartResponse(BaseModel): """冷启动响应模型""" success: bool message: str task_id: Optional[int] = None celery_task_id: Optional[str] = None class RealtimeCrawlRequest(BaseModel): """实时爬取请求模型""" source: str = Field(description="新闻源(sina, tencent, eeo等)") force_refresh: bool = Field(default=False, description="是否强制刷新(跳过缓存)") class RealtimeCrawlResponse(BaseModel): """实时爬取响应模型""" success: bool message: str celery_task_id: Optional[str] = None # API 端点 @router.get("/", response_model=List[TaskResponse]) async def get_tasks_list( skip: int = Query(0, ge=0, description="跳过的记录数"), limit: int = Query(20, ge=1, le=100, description="返回的记录数"), mode: Optional[str] = Query(None, description="按模式筛选"), status: Optional[str] = Query(None, description="按状态筛选"), db: AsyncSession = Depends(get_db) ): """ 获取任务列表 - **skip**: 跳过的记录数(分页) - **limit**: 返回的记录数 - **mode**: 按模式筛选(cold_start, realtime, targeted) - **status**: 按状态筛选(pending, running, completed, failed) """ try: query = select(CrawlTask).order_by(desc(CrawlTask.created_at)) if mode: query = query.where(CrawlTask.mode == mode) if status: query = query.where(CrawlTask.status == status) query = query.offset(skip).limit(limit) result = await db.execute(query) tasks = result.scalars().all() return [TaskResponse(**task.to_dict()) for task in tasks] except Exception as e: logger.error(f"Failed to get tasks list: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/{task_id}", response_model=TaskResponse) async def get_task_detail( task_id: int, db: AsyncSession = Depends(get_db) ): """ 获取任务详情 - **task_id**: 任务ID """ try: result = await db.execute( select(CrawlTask).where(CrawlTask.id == task_id) ) task = result.scalar_one_or_none() if not task: raise HTTPException(status_code=404, detail="Task not found") return TaskResponse(**task.to_dict()) except HTTPException: raise except Exception as e: logger.error(f"Failed to get task {task_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/cold-start", response_model=ColdStartResponse) async def trigger_cold_start( request: ColdStartRequest, db: AsyncSession = Depends(get_db) ): """ 触发冷启动批量爬取任务 - **source**: 新闻源(sina, jrj等) - **start_page**: 起始页码 - **end_page**: 结束页码 """ try: logger.info( f"触发冷启动任务: {request.source}, " f"页码 {request.start_page}-{request.end_page}" ) # 触发 Celery 任务 celery_task = cold_start_crawl_task.apply_async( args=(request.source, request.start_page, request.end_page) ) # 等待任务记录创建(最多等待2秒) await db.commit() # 确保之前的事务已提交 return ColdStartResponse( success=True, message=f"冷启动任务已启动: {request.source}, 页码 {request.start_page}-{request.end_page}", celery_task_id=celery_task.id ) except Exception as e: logger.error(f"Failed to trigger cold start: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.post("/realtime", response_model=RealtimeCrawlResponse) async def trigger_realtime_crawl( request: RealtimeCrawlRequest, db: AsyncSession = Depends(get_db) ): """ 手动触发实时爬取任务 - **source**: 新闻源(sina, tencent, eeo, jwview等) - **force_refresh**: 是否强制刷新(跳过缓存) 示例: - POST /api/v1/tasks/realtime {"source": "tencent", "force_refresh": true} - POST /api/v1/tasks/realtime {"source": "eeo"} """ try: logger.info( f"手动触发实时爬取任务: {request.source}, " f"force_refresh={request.force_refresh}" ) # 触发 
Celery 任务 celery_task = realtime_crawl_task.apply_async( args=(request.source, request.force_refresh) ) return RealtimeCrawlResponse( success=True, message=f"实时爬取任务已启动: {request.source}", celery_task_id=celery_task.id ) except Exception as e: logger.error(f"Failed to trigger realtime crawl: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/stats/summary") async def get_task_stats( db: AsyncSession = Depends(get_db) ): """ 获取任务统计信息 """ try: # 统计各状态的任务数 result = await db.execute(select(CrawlTask)) all_tasks = result.scalars().all() stats = { "total": len(all_tasks), "by_status": {}, "by_mode": {}, "recent_completed": 0, "total_news_crawled": 0, "total_news_saved": 0, } for task in all_tasks: # 按状态统计 stats["by_status"][task.status] = stats["by_status"].get(task.status, 0) + 1 # 按模式统计 stats["by_mode"][task.mode] = stats["by_mode"].get(task.mode, 0) + 1 # 统计新闻数 stats["total_news_crawled"] += task.crawled_count or 0 stats["total_news_saved"] += task.saved_count or 0 # 最近24小时完成的任务 if task.status == TaskStatus.COMPLETED and task.completed_at: from datetime import timedelta if datetime.utcnow() - task.completed_at < timedelta(days=1): stats["recent_completed"] += 1 return stats except Exception as e: logger.error(f"Failed to get task stats: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.delete("/{task_id}") async def delete_task( task_id: int, db: AsyncSession = Depends(get_db) ): """ 删除任务记录 - **task_id**: 任务ID """ try: result = await db.execute( select(CrawlTask).where(CrawlTask.id == task_id) ) task = result.scalar_one_or_none() if not task: raise HTTPException(status_code=404, detail="Task not found") await db.delete(task) await db.commit() return {"success": True, "message": f"Task {task_id} deleted"} except HTTPException: raise except Exception as e: logger.error(f"Failed to delete task {task_id}: {e}") await db.rollback() raise HTTPException(status_code=500, detail=str(e)) ================================================ FILE: backend/app/config/__init__.py ================================================ """ 配置模块 """ import os from pathlib import Path from typing import Dict, Any, Optional, List import yaml from pydantic import BaseModel, Field, ConfigDict # 配置目录 CONFIG_DIR = Path(__file__).parent class AgentConfig(BaseModel): """智能体配置""" name: str role: str description: str class FlowStep(BaseModel): """流程步骤配置""" name: str description: str parallel: bool = False agents: List[str] = Field(default_factory=list) type: Optional[str] = None max_rounds: Optional[int] = None class FlowConfig(BaseModel): """流程配置""" type: str steps: List[FlowStep] class ModeRules(BaseModel): """模式规则配置""" max_time: int = 300 max_rounds: Optional[int] = None round_time_limit: Optional[int] = None manager_can_interrupt: bool = False require_news: bool = True require_financial: bool = True require_data_collection: bool = False early_decision: bool = False min_news_count: int = 0 class DebateRules(BaseModel): """辩论规则配置""" opening_statement: bool = True rebuttal_required: bool = True evidence_required: bool = True interrupt_cooldown: int = 30 class DebateModeConfig(BaseModel): """辩论模式配置""" name: str description: str icon: str = "📊" agents: List[AgentConfig] flow: FlowConfig rules: ModeRules debate_rules: Optional[DebateRules] = None class LLMConfig(BaseModel): """LLM配置""" default_provider: str = "bailian" default_model: str = "qwen-plus" temperature: float = 0.7 max_tokens: int = 4096 class DataSourceConfig(BaseModel): """数据源配置""" type: str priority: int = 1 class 
DataSourcesConfig(BaseModel): """数据源集合配置""" news: List[DataSourceConfig] = Field(default_factory=list) financial: List[DataSourceConfig] = Field(default_factory=list) class OutputConfig(BaseModel): """输出配置""" format: str = "markdown" include_trajectory: bool = True include_timestamps: bool = True class GlobalConfig(BaseModel): """全局配置""" llm: LLMConfig = Field(default_factory=LLMConfig) data_sources: DataSourcesConfig = Field(default_factory=DataSourcesConfig) output: OutputConfig = Field(default_factory=OutputConfig) class DebateModesConfig(BaseModel): """辩论模式总配置""" model_config = ConfigDict(populate_by_name=True) default_mode: str = "parallel" modes: Dict[str, DebateModeConfig] global_config: GlobalConfig = Field(default_factory=GlobalConfig, alias="global") def load_debate_modes_config() -> DebateModesConfig: """加载辩论模式配置""" config_file = CONFIG_DIR / "debate_modes.yaml" if not config_file.exists(): raise FileNotFoundError(f"配置文件不存在: {config_file}") with open(config_file, "r", encoding="utf-8") as f: raw_config = yaml.safe_load(f) # 处理 global 关键字冲突 if "global" in raw_config: raw_config["global_config"] = raw_config.pop("global") return DebateModesConfig(**raw_config) def get_mode_config(mode_name: str) -> Optional[DebateModeConfig]: """获取指定模式的配置""" config = load_debate_modes_config() return config.modes.get(mode_name) def get_available_modes() -> List[Dict[str, Any]]: """获取所有可用的模式列表""" config = load_debate_modes_config() modes = [] for mode_id, mode_config in config.modes.items(): modes.append({ "id": mode_id, "name": mode_config.name, "description": mode_config.description, "icon": mode_config.icon, "is_default": mode_id == config.default_mode }) return modes def get_default_mode() -> str: """获取默认模式""" config = load_debate_modes_config() return config.default_mode # 单例缓存 _cached_config: Optional[DebateModesConfig] = None def get_cached_config() -> DebateModesConfig: """获取缓存的配置(避免重复读取文件)""" global _cached_config if _cached_config is None: _cached_config = load_debate_modes_config() return _cached_config def reload_config() -> DebateModesConfig: """重新加载配置""" global _cached_config _cached_config = load_debate_modes_config() return _cached_config ================================================ FILE: backend/app/config/debate_modes.yaml ================================================ # 多智能体协作模式配置 # 支持多种辩论/分析模式,可通过前端或API选择 # 默认模式 default_mode: parallel modes: # ============ 并行分析模式(当前默认) ============ parallel: name: "并行分析模式" description: "Bull/Bear并行分析,投资经理汇总决策" icon: "⚡" # 参与的智能体 agents: - name: BullResearcher role: "看多研究员" description: "从积极角度分析股票,发现投资机会" - name: BearResearcher role: "看空研究员" description: "从风险角度分析股票,识别潜在问题" - name: InvestmentManager role: "投资经理" description: "综合双方观点,做出最终投资决策" # 执行流程 flow: type: parallel_then_summarize steps: - name: data_preparation description: "准备新闻和财务数据" parallel: false - name: researcher_analysis description: "Bull/Bear并行分析" parallel: true agents: [BullResearcher, BearResearcher] - name: manager_decision description: "投资经理综合决策" parallel: false agents: [InvestmentManager] # 规则配置 rules: max_time: 300 # 最长执行时间(秒) require_news: true # 是否需要新闻数据 require_financial: true # 是否需要财务数据 min_news_count: 1 # 最少新闻数量 # ============ 实时辩论模式 ============ realtime_debate: name: "实时辩论模式" description: "四人实时对话,投资经理主持,多空双方交替发言" icon: "🎭" # 参与的智能体 agents: - name: DataCollector role: "数据专员" description: "搜集和整理相关数据资料" - name: BullResearcher role: "多方辩手" description: "支持买入,提出看多论点" - name: BearResearcher role: "空方辩手" description: "建议卖出,提出看空论点" - name: InvestmentManager role: 
"投资经理(主持人)" description: "主持辩论,随时提问,最终裁决" # 执行流程 flow: type: orchestrated_debate steps: - name: opening description: "投资经理开场,下发分析任务" agents: [InvestmentManager] - name: data_collection description: "数据专员搜集资料" agents: [DataCollector] - name: debate_rounds description: "多空双方辩论" type: alternating agents: [BullResearcher, BearResearcher] max_rounds: 5 - name: closing description: "投资经理总结决策" agents: [InvestmentManager] # 规则配置 rules: max_rounds: 5 # 最大辩论回合数 max_time: 600 # 最长执行时间(秒) round_time_limit: 60 # 每回合时间限制(秒) manager_can_interrupt: true # 投资经理是否可以打断 require_data_collection: true # 是否需要先搜集数据 early_decision: true # 是否允许提前做决策 # 辩论规则 debate_rules: opening_statement: true # 是否需要开场陈述 rebuttal_required: true # 是否必须反驳对方 evidence_required: true # 是否需要提供证据 interrupt_cooldown: 30 # 打断冷却时间(秒) # ============ 快速分析模式 ============ quick_analysis: name: "快速分析模式" description: "单一分析师快速给出建议,适合时间紧迫场景" icon: "🚀" agents: - name: QuickAnalyst role: "快速分析师" description: "综合多角度快速给出投资建议" flow: type: single_agent steps: - name: quick_analysis description: "快速综合分析" agents: [QuickAnalyst] rules: max_time: 60 require_news: false require_financial: true # ============ 全局配置 ============ global: # LLM配置 llm: default_provider: bailian default_model: qwen-plus temperature: 0.7 max_tokens: 4096 # 数据源配置 data_sources: news: - type: database priority: 1 - type: bochaai priority: 2 financial: - type: akshare priority: 1 # 输出配置 output: format: markdown include_trajectory: true include_timestamps: true ================================================ FILE: backend/app/core/__init__.py ================================================ """ 核心模块 """ from .config import settings, get_settings from .database import get_db, init_database __all__ = ["settings", "get_settings", "get_db", "init_database"] ================================================ FILE: backend/app/core/celery_app.py ================================================ """ Celery 应用配置 """ from celery import Celery from celery.schedules import crontab from .config import settings # 创建 Celery 应用 celery_app = Celery( "finnews", broker=settings.REDIS_URL, backend=settings.REDIS_URL, include=["app.tasks.crawl_tasks"] # 导入任务模块 ) # Celery 配置 celery_app.conf.update( # 时区设置 timezone="Asia/Shanghai", enable_utc=True, # 任务结果配置 result_expires=3600, # 结果保存1小时 result_backend_transport_options={ 'master_name': 'mymaster' }, # 任务执行配置 task_serializer="json", result_serializer="json", accept_content=["json"], task_track_started=True, task_time_limit=30 * 60, # 30分钟超时 task_soft_time_limit=25 * 60, # 25分钟软超时 # Worker 配置 worker_prefetch_multiplier=1, # 每次只拿一个任务 worker_max_tasks_per_child=1000, # 每个 worker 处理1000个任务后重启 # Beat 调度配置 beat_schedule={ # 每1分钟爬取新浪财经 "crawl-sina-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("sina",), }, # 每1分钟爬取腾讯财经 "crawl-tencent-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("tencent",), }, # 每1分钟爬取中新经纬 "crawl-jwview-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("jwview",), }, # 每1分钟爬取经济观察网 "crawl-eeo-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("eeo",), }, # 每1分钟爬取财经网 "crawl-caijing-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("caijing",), }, # 每1分钟爬取21经济网 "crawl-jingji21-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", 
"schedule": crontab(minute="*/1"), "args": ("jingji21",), }, # 每1分钟爬取每日经济新闻 "crawl-nbd-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("nbd",), }, # 每1分钟爬取第一财经 "crawl-yicai-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("yicai",), }, # 每1分钟爬取网易财经 "crawl-163-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("163",), }, # 每1分钟爬取东方财富 "crawl-eastmoney-every-1min": { "task": "app.tasks.crawl_tasks.realtime_crawl_task", "schedule": crontab(minute="*/1"), "args": ("eastmoney",), }, }, ) # 任务路由(可选,用于任务分发) # 注释掉自定义路由,使用默认的 celery 队列 # celery_app.conf.task_routes = { # "app.tasks.crawl_tasks.*": {"queue": "crawl"}, # "app.tasks.analysis_tasks.*": {"queue": "analysis"}, # } if __name__ == "__main__": celery_app.start() ================================================ FILE: backend/app/core/config.py ================================================ """ FinnewsHunter 核心配置模块 使用 Pydantic Settings 管理环境变量和配置 """ from typing import Optional, List from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): """应用配置类""" # 应用基础配置 APP_NAME: str = "FinnewsHunter" APP_VERSION: str = "0.1.0" API_V1_PREFIX: str = "/api/v1" DEBUG: bool = Field(default=True) # 服务器配置 HOST: str = Field(default="0.0.0.0") PORT: int = Field(default=8000) # CORS 配置 BACKEND_CORS_ORIGINS: List[str] = Field( default=["http://localhost:3000", "http://localhost:8000"] ) # PostgreSQL 数据库配置 POSTGRES_USER: str = Field(default="finnews") POSTGRES_PASSWORD: str = Field(default="finnews_dev_password") POSTGRES_HOST: str = Field(default="localhost") POSTGRES_PORT: int = Field(default=5432) POSTGRES_DB: str = Field(default="finnews_db") @property def DATABASE_URL(self) -> str: """异步数据库连接 URL""" return ( f"postgresql+asyncpg://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}" f"@{self.POSTGRES_HOST}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}" ) @property def SYNC_DATABASE_URL(self) -> str: """同步数据库连接 URL(用于初始化)""" return ( f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}" f"@{self.POSTGRES_HOST}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}" ) # Redis 配置 REDIS_HOST: str = Field(default="localhost") REDIS_PORT: int = Field(default=6379) REDIS_DB: int = Field(default=0) REDIS_PASSWORD: Optional[str] = Field(default=None) @property def REDIS_URL(self) -> str: """Redis 连接 URL""" if self.REDIS_PASSWORD: return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}" return f"redis://{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}" # Milvus 配置 MILVUS_HOST: str = Field(default="localhost") MILVUS_PORT: int = Field(default=19530) MILVUS_COLLECTION_NAME: str = Field(default="finnews_embeddings") MILVUS_DIM: int = Field(default=1536) # OpenAI embedding dimension # Neo4j 知识图谱配置 NEO4J_URI: str = Field(default="bolt://localhost:7687", description="Neo4j 连接URI") NEO4J_USER: str = Field(default="neo4j", description="Neo4j 用户名") NEO4J_PASSWORD: str = Field(default="finnews_neo4j_password", description="Neo4j 密码") # LLM 配置 LLM_PROVIDER: str = Field(default="bailian") # 默认提供商 LLM_MODEL: str = Field(default="qwen-plus") LLM_TEMPERATURE: float = Field(default=0.7) LLM_MAX_TOKENS: int = Field(default=2000) LLM_TIMEOUT: int = Field(default=180) # LLM 调用超时时间(秒),百炼建议180秒 # 各厂商 API Key 配置 DASHSCOPE_API_KEY: Optional[str] = Field(default=None, description="阿里云百炼 API Key") DASHSCOPE_BASE_URL: str = 
Field( default="https://dashscope.aliyuncs.com/compatible-mode/v1", description="阿里云百炼 Base URL" ) BAILIAN_API_KEY: Optional[str] = Field(default=None, description="百炼 API Key(与DASHSCOPE相同)") OPENAI_API_KEY: Optional[str] = Field(default=None, description="OpenAI API Key") DEEPSEEK_API_KEY: Optional[str] = Field(default=None, description="DeepSeek API Key") MOONSHOT_API_KEY: Optional[str] = Field(default=None, description="Moonshot (Kimi) API Key") ZHIPU_API_KEY: Optional[str] = Field(default=None, description="智谱 API Key") ANTHROPIC_API_KEY: Optional[str] = Field(default=None, description="Anthropic API Key") # 各厂商可用模型列表(逗号分隔) BAILIAN_MODELS: str = Field( default="qwen-plus,qwen-max,qwen-turbo,qwen-long", description="百炼可用模型(逗号分隔)" ) OPENAI_MODELS: str = Field( default="gpt-4,gpt-4-turbo,gpt-3.5-turbo", description="OpenAI可用模型(逗号分隔)" ) DEEPSEEK_MODELS: str = Field( default="deepseek-chat", description="DeepSeek可用模型(逗号分隔)" ) MOONSHOT_MODELS: str = Field( default="moonshot-v1-8k,moonshot-v1-32k,moonshot-v1-128k", description="Moonshot可用模型(逗号分隔)" ) ZHIPU_MODELS: str = Field( default="glm-4,glm-4-plus,glm-4-air,glm-3-turbo", description="智谱可用模型(逗号分隔)" ) # Base URL 配置(用于第三方 API 转发) OPENAI_BASE_URL: Optional[str] = Field(default=None, description="OpenAI Base URL") DEEPSEEK_BASE_URL: Optional[str] = Field(default="https://api.deepseek.com/v1", description="DeepSeek Base URL") MOONSHOT_BASE_URL: Optional[str] = Field(default="https://api.moonshot.cn/v1", description="Moonshot Base URL") ZHIPU_BASE_URL: Optional[str] = Field(default="https://open.bigmodel.cn/api/paas/v4", description="智谱 Base URL") ANTHROPIC_BASE_URL: Optional[str] = Field(default=None, description="Anthropic Base URL") QWEN_BASE_URL: Optional[str] = Field(default=None, description="Qwen Base URL (deprecated)") BAILIAN_ACCESS_KEY_ID: Optional[str] = Field(default=None, description="百炼 Access Key ID") BAILIAN_ACCESS_KEY_SECRET: Optional[str] = Field(default=None, description="百炼 Access Key Secret") BAILIAN_AGENT_CODE: Optional[str] = Field(default=None, description="百炼 Agent Code") BAILIAN_REGION_ID: str = Field(default="cn-beijing", description="百炼 Region ID") # BochaAI 搜索 API 配置 BOCHAAI_API_KEY: Optional[str] = Field(default=None, description="BochaAI Web Search API Key") BOCHAAI_ENDPOINT: str = Field(default="https://api.bochaai.com/v1/web-search", description="BochaAI API Endpoint") # Embedding 配置 EMBEDDING_PROVIDER: str = Field(default="openai") # openai, huggingface EMBEDDING_MODEL: str = Field(default="text-embedding-ada-002") EMBEDDING_BATCH_SIZE: int = Field(default=100) EMBEDDING_BASE_URL: Optional[str] = Field(default=None) # 自定义 Embedding API 端点 EMBEDDING_TIMEOUT: int = Field(default=30, description="Embedding API 超时时间(秒),建议设置为20-30秒") EMBEDDING_MAX_RETRIES: int = Field(default=2, description="Embedding API 最大重试次数,建议设置为1-2次以避免等待太久") # 爬虫配置 CRAWLER_USER_AGENT: str = Field( default="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" ) CRAWLER_TIMEOUT: int = Field(default=30) CRAWLER_MAX_RETRIES: int = Field(default=3) CRAWLER_DELAY: float = Field(default=1.0) # 请求间隔(秒) # Phase 2: 实时爬取与缓存配置(多源支持) CACHE_TTL: int = Field(default=1800, description="缓存过期时间(秒),默认30分钟") CRAWL_INTERVAL_SINA: int = Field(default=60, description="新浪财经爬取间隔(秒),默认60秒") CRAWL_INTERVAL_TENCENT: int = Field(default=60, description="腾讯财经爬取间隔(秒),默认60秒") CRAWL_INTERVAL_JWVIEW: int = Field(default=60, description="中新经纬爬取间隔(秒),默认60秒") CRAWL_INTERVAL_EEO: int = Field(default=60, description="经济观察网爬取间隔(秒),默认60秒") CRAWL_INTERVAL_CAIJING: int = 
Field(default=60, description="财经网爬取间隔(秒),默认60秒") CRAWL_INTERVAL_JINGJI21: int = Field(default=60, description="21经济网爬取间隔(秒),默认60秒") CRAWL_INTERVAL_JRJ: int = Field(default=600, description="金融界爬取间隔(秒),默认10分钟") NEWS_RETENTION_HOURS: int = Field(default=72000, description="新闻保留时间(小时),临时设置为72000小时(约8年)以包含所有爬取的新闻") FRONTEND_REFETCH_INTERVAL: int = Field(default=180, description="前端自动刷新间隔(秒),默认3分钟") # 日志配置 LOG_LEVEL: str = Field(default="INFO") LOG_FILE: Optional[str] = Field(default="logs/finnews.log") # 安全配置 SECRET_KEY: str = Field(default="your-secret-key-here-change-in-production") ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=60 * 24 * 7) # 7 days # 业务配置 MAX_NEWS_PER_REQUEST: int = Field(default=50) NEWS_CACHE_TTL: int = Field(default=3600) # 1 hour model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore", env_ignore_empty=True, ) # 全局配置实例 settings = Settings() # 便捷访问函数 def get_settings() -> Settings: """获取配置实例(用于依赖注入)""" return settings ================================================ FILE: backend/app/core/database.py ================================================ """ 数据库连接和依赖注入 """ from typing import AsyncGenerator from sqlalchemy.ext.asyncio import AsyncSession from ..models.database import ( AsyncSessionLocal, init_db as create_tables, Base, ) async def get_db() -> AsyncGenerator[AsyncSession, None]: """ FastAPI 依赖注入:获取数据库会话 Usage: @app.get("/items") async def get_items(db: AsyncSession = Depends(get_db)): ... Yields: AsyncSession: 数据库会话 """ async with AsyncSessionLocal() as session: try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close() def init_database(): """ 初始化数据库 创建所有表结构 """ print("=" * 50) print("Initializing FinnewsHunter Database...") print("=" * 50) try: create_tables() print("\n✓ Database initialization completed successfully!") except Exception as e: print(f"\n✗ Database initialization failed: {e}") raise if __name__ == "__main__": # 直接运行此文件以初始化数据库 init_database() ================================================ FILE: backend/app/core/neo4j_client.py ================================================ """ Neo4j 图数据库客户端 用于存储和查询公司知识图谱 """ import logging from typing import Optional, Dict, List, Any from neo4j import GraphDatabase, Driver from contextlib import contextmanager from .config import settings logger = logging.getLogger(__name__) class Neo4jClient: """Neo4j 客户端封装""" def __init__( self, uri: str = None, user: str = None, password: str = None ): """ 初始化 Neo4j 客户端 Args: uri: Neo4j URI(如 bolt://localhost:7687) user: 用户名 password: 密码 """ self.uri = uri or settings.NEO4J_URI or "bolt://localhost:7687" self.user = user or settings.NEO4J_USER or "neo4j" self.password = password or settings.NEO4J_PASSWORD or "finnews_neo4j_password" self._driver: Optional[Driver] = None self._connected = False def connect(self): """建立连接""" if self._connected: return try: self._driver = GraphDatabase.driver( self.uri, auth=(self.user, self.password) ) # 测试连接 self._driver.verify_connectivity() self._connected = True logger.info(f"✅ Neo4j 连接成功: {self.uri}") except Exception as e: logger.error(f"❌ Neo4j 连接失败: {e}") raise def close(self): """关闭连接""" if self._driver: self._driver.close() self._connected = False logger.info("Neo4j 连接已关闭") @contextmanager def session(self): """获取会话(上下文管理器)""" if not self._connected: self.connect() session = self._driver.session() try: yield session finally: session.close() def execute_query( self, query: str, parameters: Dict[str, Any] = 
None ) -> List[Dict[str, Any]]: """ 执行 Cypher 查询 Args: query: Cypher 查询语句 parameters: 查询参数 Returns: 查询结果列表 """ with self.session() as session: result = session.run(query, parameters or {}) return [dict(record) for record in result] def execute_write( self, query: str, parameters: Dict[str, Any] = None ) -> List[Dict[str, Any]]: """ 执行写入操作 Args: query: Cypher 写入语句 parameters: 参数 Returns: 写入结果 """ with self.session() as session: result = session.run(query, parameters or {}) return [dict(record) for record in result] def is_connected(self) -> bool: """检查连接状态""" return self._connected def health_check(self) -> bool: """健康检查""" try: if not self._connected: self.connect() with self.session() as session: result = session.run("RETURN 1 as health") return result.single()["health"] == 1 except Exception as e: logger.error(f"Neo4j 健康检查失败: {e}") return False # 全局单例 _neo4j_client: Optional[Neo4jClient] = None def get_neo4j_client() -> Neo4jClient: """获取 Neo4j 客户端单例""" global _neo4j_client if _neo4j_client is None: _neo4j_client = Neo4jClient() _neo4j_client.connect() return _neo4j_client def close_neo4j_client(): """关闭 Neo4j 客户端""" global _neo4j_client if _neo4j_client: _neo4j_client.close() _neo4j_client = None ================================================ FILE: backend/app/core/redis_client.py ================================================ """ Redis Client for Caching and Task Queue """ import json import logging from typing import Optional, Any from datetime import datetime, timedelta import redis from app.core.config import settings logger = logging.getLogger(__name__) class RedisClient: """Redis client wrapper with JSON serialization support""" def __init__(self): try: self.client = redis.Redis( host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, decode_responses=True, # 自动解码为字符串 socket_connect_timeout=5, socket_timeout=5, ) # 测试连接 self.client.ping() logger.info(f"✅ Redis connected: {settings.REDIS_HOST}:{settings.REDIS_PORT}") except Exception as e: logger.error(f"❌ Redis connection failed: {e}") self.client = None def is_available(self) -> bool: """检查 Redis 是否可用""" try: if self.client: self.client.ping() return True except: pass return False def get_json(self, key: str) -> Optional[Any]: """获取 JSON 数据""" if not self.is_available(): return None try: value = self.client.get(key) if value: return json.loads(value) except Exception as e: logger.error(f"Redis get_json error: {e}") return None def set_json(self, key: str, value: Any, ttl: int = None) -> bool: """存储 JSON 数据""" if not self.is_available(): return False try: json_str = json.dumps(value, ensure_ascii=False, default=str) if ttl: self.client.setex(key, ttl, json_str) else: self.client.set(key, json_str) return True except Exception as e: logger.error(f"Redis set_json error: {e}") return False def get(self, key: str) -> Optional[str]: """获取字符串数据""" if not self.is_available(): return None try: return self.client.get(key) except Exception as e: logger.error(f"Redis get error: {e}") return None def set(self, key: str, value: str, ttl: int = None) -> bool: """存储字符串数据""" if not self.is_available(): return False try: if ttl: self.client.setex(key, ttl, value) else: self.client.set(key, value) return True except Exception as e: logger.error(f"Redis set error: {e}") return False def delete(self, key: str) -> bool: """删除键""" if not self.is_available(): return False try: self.client.delete(key) return True except Exception as e: logger.error(f"Redis delete 
error: {e}") return False def exists(self, key: str) -> bool: """检查键是否存在""" if not self.is_available(): return False try: return self.client.exists(key) > 0 except Exception as e: logger.error(f"Redis exists error: {e}") return False def get_cache_metadata(self, key: str) -> Optional[dict]: """获取缓存元数据(时间戳)""" time_key = f"{key}:timestamp" timestamp_str = self.get(time_key) if timestamp_str: try: return { "timestamp": datetime.fromisoformat(timestamp_str), "age_seconds": (datetime.now() - datetime.fromisoformat(timestamp_str)).total_seconds() } except: pass return None def set_with_metadata(self, key: str, value: Any, ttl: int = None) -> bool: """存储数据并记录时间戳""" success = self.set_json(key, value, ttl) if success: time_key = f"{key}:timestamp" self.set(time_key, datetime.now().isoformat(), ttl) return success def clear_pattern(self, pattern: str) -> int: """清除匹配模式的所有键""" if not self.is_available(): return 0 try: keys = self.client.keys(pattern) if keys: return self.client.delete(*keys) except Exception as e: logger.error(f"Redis clear_pattern error: {e}") return 0 # 全局单例 redis_client = RedisClient() ================================================ FILE: backend/app/financial/__init__.py ================================================ """ FinnewsHunter 金融数据层 借鉴 OpenBB 的 Provider-Fetcher 架构,提供: 1. Standard Models: 统一的数据模型 (NewsData, StockPriceData 等) 2. Provider Registry: 多数据源管理与自动降级 3. AgenticX Tools: 封装为 Agent 可调用的工具 设计原则: - 不修改 AgenticX 核心,所有金融特定逻辑内化在本模块 - TET Pipeline: Transform Query → Extract Data → Transform Data - 多源降级: Provider 失败时自动切换到备用源 """ from .registry import get_registry, ProviderRegistry from .models.news import NewsQueryParams, NewsData, NewsSentiment from .models.stock import ( StockQueryParams, StockPriceData, KlineInterval, AdjustType ) __all__ = [ # Registry "get_registry", "ProviderRegistry", # News Models "NewsQueryParams", "NewsData", "NewsSentiment", # Stock Models "StockQueryParams", "StockPriceData", "KlineInterval", "AdjustType", ] ================================================ FILE: backend/app/financial/models/__init__.py ================================================ """ 金融数据标准模型 借鉴 OpenBB Standard Models 设计: - QueryParams: 定义标准输入参数 - Data: 定义标准输出字段 所有 Provider 的 Fetcher 都使用这些标准模型,确保数据格式一致。 """ from .news import NewsQueryParams, NewsData, NewsSentiment from .stock import StockQueryParams, StockPriceData, KlineInterval, AdjustType __all__ = [ "NewsQueryParams", "NewsData", "NewsSentiment", "StockQueryParams", "StockPriceData", "KlineInterval", "AdjustType", ] ================================================ FILE: backend/app/financial/models/news.py ================================================ """ 金融新闻标准模型 借鉴 OpenBB Standard Models 设计: - NewsQueryParams: 新闻查询参数标准模型 - NewsData: 新闻数据标准模型 所有 NewsProvider 的 Fetcher 都接收 NewsQueryParams 作为输入, 返回 List[NewsData] 作为输出,确保不同数据源返回的数据格式一致。 来源参考: - OpenBB: openbb_core.provider.standard_models - 设计文档: research/codedeepresearch/OpenBB/FinnewsHunter_improvement_plan.md """ from pydantic import BaseModel, Field from datetime import datetime from typing import Optional, List from enum import Enum import hashlib class NewsSentiment(str, Enum): """新闻情感标签""" POSITIVE = "positive" NEGATIVE = "negative" NEUTRAL = "neutral" class NewsQueryParams(BaseModel): """ 新闻查询参数标准模型 所有 NewsProvider 的 Fetcher 都接收此模型作为输入, 内部再转换为各自 API 的参数格式 (transform_query)。 Example: >>> params = NewsQueryParams(stock_codes=["600519"], limit=10) >>> fetcher.fetch(params) # 返回 List[NewsData] """ keywords: Optional[List[str]] = Field( default=None, 
description="搜索关键词列表" ) stock_codes: Optional[List[str]] = Field( default=None, description="关联股票代码列表,如 ['600519', '000001']" ) start_date: Optional[datetime] = Field( default=None, description="新闻发布时间起始" ) end_date: Optional[datetime] = Field( default=None, description="新闻发布时间截止" ) limit: int = Field( default=50, ge=1, le=500, description="返回条数上限" ) source_filter: Optional[List[str]] = Field( default=None, description="数据源过滤,如 ['sina', 'tencent']" ) class Config: json_schema_extra = { "example": { "stock_codes": ["600519", "000001"], "limit": 20, "keywords": ["茅台", "白酒"] } } class NewsData(BaseModel): """ 新闻数据标准模型 所有 Provider 返回的数据都必须转换为此模型, 确保上层 Agent 处理逻辑一致。 设计原则: - 必填字段: id, title, content, source, source_url, publish_time - 可选字段: summary, sentiment 等 (可由 LLM 后续填充) - extra 字段: 存储 Provider 特有的额外数据 """ id: str = Field(..., description="新闻唯一标识 (建议用 URL 的 MD5)") title: str = Field(..., description="新闻标题") content: str = Field(..., description="新闻正文") summary: Optional[str] = Field(default=None, description="摘要(可由 LLM 生成)") source: str = Field(..., description="来源网站名称,如 'sina', 'tencent'") source_url: str = Field(..., description="原文链接") publish_time: datetime = Field(..., description="发布时间") crawl_time: Optional[datetime] = Field( default_factory=datetime.now, description="抓取时间" ) # 关联信息 stock_codes: List[str] = Field( default_factory=list, description="关联股票代码,如 ['SH600519', 'SZ000001']" ) stock_names: List[str] = Field( default_factory=list, description="关联股票名称,如 ['贵州茅台', '平安银行']" ) # 情感分析(可选,由 Agent 或 LLM 填充) sentiment: Optional[NewsSentiment] = Field( default=None, description="情感标签" ) sentiment_score: Optional[float] = Field( default=None, ge=-1, le=1, description="情感分数:-1(极度负面) ~ 1(极度正面)" ) # 原始数据(可选) keywords: List[str] = Field( default_factory=list, description="关键词列表" ) author: Optional[str] = Field(default=None, description="作者") # 元数据 extra: dict = Field( default_factory=dict, description="Provider 特有的额外字段" ) class Config: json_encoders = { datetime: lambda v: v.isoformat() } json_schema_extra = { "example": { "id": "a1b2c3d4e5f6", "title": "贵州茅台2024年三季度业绩超预期", "content": "贵州茅台发布2024年三季度报告...", "source": "sina", "source_url": "https://finance.sina.com.cn/stock/...", "publish_time": "2024-10-30T10:30:00", "stock_codes": ["SH600519"], "sentiment": "positive", "sentiment_score": 0.8 } } @staticmethod def generate_id(url: str) -> str: """根据 URL 生成唯一 ID""" return hashlib.md5(url.encode()).hexdigest()[:16] def to_legacy_dict(self) -> dict: """ 转换为旧版 NewsItem 格式 (兼容现有代码) Returns: 与旧版 NewsItem.to_dict() 格式一致的字典 """ return { "title": self.title, "content": self.content, "url": self.source_url, "source": self.source, "publish_time": self.publish_time.isoformat() if self.publish_time else None, "author": self.author, "keywords": self.keywords, "stock_codes": self.stock_codes, "summary": self.summary, "raw_html": self.extra.get("raw_html"), } ================================================ FILE: backend/app/financial/models/stock.py ================================================ """ 股票数据标准模型 借鉴 OpenBB Standard Models 设计: - StockQueryParams: 股票数据查询参数 - StockPriceData: 股票价格数据 (K线) 来源参考: - OpenBB: openbb_core.provider.standard_models - 设计文档: research/codedeepresearch/OpenBB/FinnewsHunter_improvement_plan.md """ from pydantic import BaseModel, Field from datetime import date, datetime from typing import Optional, List from enum import Enum class KlineInterval(str, Enum): """K线周期""" MIN_1 = "1m" MIN_5 = "5m" MIN_15 = "15m" MIN_30 = "30m" MIN_60 = "60m" DAILY = "1d" WEEKLY = "1w" MONTHLY = "1M" 
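# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the repository): the legacy /kline
# endpoint accepts period strings such as "daily" and "60m", while KlineInterval
# standardizes on "1d", "1m", ..., "60m". A hypothetical helper could bridge the
# two; whether (and where) the repository performs this mapping is an assumption.
_LEGACY_PERIOD_TO_INTERVAL = {
    "daily": KlineInterval.DAILY,
    "1m": KlineInterval.MIN_1,
    "5m": KlineInterval.MIN_5,
    "15m": KlineInterval.MIN_15,
    "30m": KlineInterval.MIN_30,
    "60m": KlineInterval.MIN_60,
}

def _legacy_period_to_interval(period: str) -> KlineInterval:
    """Map a legacy API period string onto the standard KlineInterval (sketch only)."""
    try:
        return _LEGACY_PERIOD_TO_INTERVAL[period]
    except KeyError:
        raise ValueError(f"Unsupported kline period: {period!r}")
# --------------------------- end of sketch ---------------------------------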
class AdjustType(str, Enum): """复权类型""" NONE = "none" QFQ = "qfq" # 前复权 HFQ = "hfq" # 后复权 class StockQueryParams(BaseModel): """ 股票数据查询参数 Example: >>> params = StockQueryParams(symbol="600519", interval=KlineInterval.DAILY) >>> fetcher.fetch(params) # 返回 List[StockPriceData] """ symbol: str = Field(..., description="股票代码,如 '600519' 或 'SH600519'") start_date: Optional[date] = Field(default=None, description="开始日期") end_date: Optional[date] = Field(default=None, description="结束日期") interval: KlineInterval = Field( default=KlineInterval.DAILY, description="K线周期" ) adjust: AdjustType = Field( default=AdjustType.QFQ, description="复权类型" ) limit: int = Field( default=90, ge=1, le=1000, description="返回条数" ) class Config: json_schema_extra = { "example": { "symbol": "600519", "interval": "1d", "limit": 90, "adjust": "qfq" } } class StockPriceData(BaseModel): """ 股票价格数据(K线) 与现有 StockDataService 返回格式对齐, 确保迁移时的兼容性。 """ symbol: str = Field(..., description="股票代码") date: datetime = Field(..., description="交易时间") open: float = Field(..., description="开盘价") high: float = Field(..., description="最高价") low: float = Field(..., description="最低价") close: float = Field(..., description="收盘价") volume: int = Field(..., description="成交量") turnover: Optional[float] = Field(default=None, description="成交额") change_percent: Optional[float] = Field(default=None, description="涨跌幅 %") change_amount: Optional[float] = Field(default=None, description="涨跌额") amplitude: Optional[float] = Field(default=None, description="振幅 %") turnover_rate: Optional[float] = Field(default=None, description="换手率 %") class Config: json_encoders = { datetime: lambda v: v.isoformat() } def to_legacy_dict(self) -> dict: """ 转换为旧版 StockDataService 格式 (兼容现有代码) Returns: 与旧版 get_kline_data 返回格式一致的字典 """ return { "timestamp": int(self.date.timestamp() * 1000), "date": self.date.strftime("%Y-%m-%d") if self.date else None, "open": self.open, "high": self.high, "low": self.low, "close": self.close, "volume": self.volume, "turnover": self.turnover or 0, "change_percent": self.change_percent or 0, "change_amount": self.change_amount or 0, "amplitude": self.amplitude or 0, "turnover_rate": self.turnover_rate or 0, } class StockRealtimeData(BaseModel): """股票实时行情""" symbol: str name: str price: float change_percent: float change_amount: float volume: int turnover: float high: float low: float open: float prev_close: float timestamp: datetime = Field(default_factory=datetime.now) class StockFinancialData(BaseModel): """股票财务指标""" symbol: str pe_ratio: Optional[float] = None # 市盈率 pb_ratio: Optional[float] = None # 市净率 roe: Optional[float] = None # 净资产收益率 total_market_value: Optional[float] = None circulating_market_value: Optional[float] = None gross_profit_margin: Optional[float] = None net_profit_margin: Optional[float] = None debt_ratio: Optional[float] = None revenue_yoy: Optional[float] = None # 营收同比 profit_yoy: Optional[float] = None # 净利润同比 ================================================ FILE: backend/app/financial/providers/__init__.py ================================================ """ 数据源 Provider 模块 每个 Provider 代表一个数据源(如 Sina, Tencent, AkShare), 每个 Provider 下可以有多个 Fetcher,每个 Fetcher 对应一种数据类型。 架构: Provider (数据源) └── Fetcher (数据获取器,实现 TET Pipeline) ├── transform_query: 将标准参数转换为 Provider 特定参数 ├── extract_data: 执行实际的数据获取 └── transform_data: 将原始数据转换为标准模型 """ from .base import BaseProvider, BaseFetcher, ProviderInfo __all__ = [ "BaseProvider", "BaseFetcher", "ProviderInfo", ] ================================================ FILE: 
backend/app/financial/providers/base.py ================================================ """ Provider & Fetcher 基础抽象 借鉴 OpenBB 的 TET (Transform-Extract-Transform) Pipeline: 1. Transform Query: 将标准参数转换为 Provider 特定参数 2. Extract Data: 执行实际的数据获取 (HTTP/爬虫/SDK) 3. Transform Data: 将原始数据转换为标准模型 来源参考: - OpenBB: openbb_core.provider.abstract.fetcher.Fetcher - 设计文档: research/codedeepresearch/OpenBB/FinnewsHunter_improvement_plan.md """ from abc import ABC, abstractmethod from typing import TypeVar, Generic, Dict, Any, List, Type, Optional from pydantic import BaseModel from dataclasses import dataclass, field import logging # 泛型类型变量 QueryT = TypeVar("QueryT", bound=BaseModel) DataT = TypeVar("DataT", bound=BaseModel) @dataclass class ProviderInfo: """ Provider 元信息 Attributes: name: 唯一标识,如 'sina', 'akshare' display_name: 显示名称,如 '新浪财经' description: 描述 website: 官网 URL requires_credentials: 是否需要 API Key credential_keys: 需要的凭证 key 列表 priority: 降级优先级,数字越小越优先 """ name: str display_name: str description: str website: Optional[str] = None requires_credentials: bool = False credential_keys: List[str] = field(default_factory=list) priority: int = 0 # 数字越小,优先级越高 class BaseFetcher(ABC, Generic[QueryT, DataT]): """ 数据获取器基类 - 实现 TET (Transform-Extract-Transform) Pipeline 子类必须: 1. 声明 query_model 和 data_model 类属性 2. 实现 transform_query, extract_data, transform_data 三个抽象方法 Example: >>> class SinaNewsFetcher(BaseFetcher[NewsQueryParams, NewsData]): ... query_model = NewsQueryParams ... data_model = NewsData ... ... def transform_query(self, params): ... return {"url": "...", "limit": params.limit} ... ... async def extract_data(self, query): ... return await self._fetch_html(query["url"]) ... ... def transform_data(self, raw_data, query): ... return [NewsData(...) for item in raw_data] """ # 子类必须声明这两个类属性 query_model: Type[QueryT] data_model: Type[DataT] def __init__(self): self.logger = logging.getLogger( f"{self.__class__.__module__}.{self.__class__.__name__}" ) @abstractmethod def transform_query(self, params: QueryT) -> Dict[str, Any]: """ [T]ransform Query: 将标准参数转换为 Provider 特定参数 Args: params: 标准查询参数 (NewsQueryParams, StockQueryParams 等) Returns: Provider 特定的参数字典 Example: NewsQueryParams(stock_codes=['600519'], limit=10) → {'url': 'https://...', 'symbol': 'sh600519', 'count': 10} """ pass @abstractmethod async def extract_data(self, query: Dict[str, Any]) -> Any: """ [E]xtract Data: 执行实际的数据获取 可以是: - HTTP 请求 - 网页爬虫 - SDK 调用 - 数据库查询 Args: query: transform_query 返回的参数字典 Returns: 原始数据 (任意格式,由 transform_data 处理) """ pass @abstractmethod def transform_data(self, raw_data: Any, query: QueryT) -> List[DataT]: """ [T]ransform Data: 将原始数据转换为标准模型 Args: raw_data: extract_data 返回的原始数据 query: 原始查询参数 (可用于补充信息) Returns: 标准模型列表 (List[NewsData], List[StockPriceData] 等) """ pass async def fetch(self, params: QueryT) -> List[DataT]: """ 完整的 TET 执行流程 Args: params: 标准查询参数 Returns: 标准模型列表 Raises: Exception: 任何阶段失败时抛出异常 """ self.logger.info(f"Fetching with params: {params.model_dump()}") # T: Transform Query query = self.transform_query(params) self.logger.debug(f"Transformed query: {query}") # E: Extract Data raw = await self.extract_data(query) raw_count = len(raw) if isinstance(raw, (list, tuple)) else 1 self.logger.debug(f"Extracted {raw_count} raw records") # T: Transform Data results = self.transform_data(raw, params) self.logger.info(f"Transformed to {len(results)} standard records") return results def fetch_sync(self, params: QueryT) -> List[DataT]: """ 同步版本的 fetch (用于非异步环境) Args: params: 标准查询参数 Returns: 标准模型列表 """ import asyncio 
return asyncio.run(self.fetch(params)) class BaseProvider(ABC): """ Provider 基类 - 定义数据源能力 每个 Provider 可以有多个 Fetcher,每个 Fetcher 对应一种数据类型。 Example: >>> class SinaProvider(BaseProvider): ... @property ... def info(self): ... return ProviderInfo(name="sina", ...) ... ... @property ... def fetchers(self): ... return {"news": SinaNewsFetcher} """ @property @abstractmethod def info(self) -> ProviderInfo: """返回 Provider 元信息""" pass @property @abstractmethod def fetchers(self) -> Dict[str, Type[BaseFetcher]]: """ 返回支持的 Fetcher 映射 Returns: 格式: {data_type: FetcherClass} 例如: {'news': SinaNewsFetcher, 'stock_price': SinaStockFetcher} """ pass def get_fetcher(self, data_type: str) -> Optional[BaseFetcher]: """ 获取指定类型的 Fetcher 实例 Args: data_type: 数据类型,如 'news', 'stock_price' Returns: Fetcher 实例,如果不支持该类型则返回 None """ fetcher_cls = self.fetchers.get(data_type) if fetcher_cls: return fetcher_cls() return None def supports(self, data_type: str) -> bool: """ 检查是否支持某种数据类型 Args: data_type: 数据类型 Returns: 是否支持 """ return data_type in self.fetchers def __repr__(self) -> str: return f"<{self.__class__.__name__} name='{self.info.name}' types={list(self.fetchers.keys())}>" ================================================ FILE: backend/app/financial/providers/eastmoney/__init__.py ================================================ """ 东方财富 Provider """ from .provider import EastmoneyProvider from .fetchers.news import EastmoneyNewsFetcher __all__ = ["EastmoneyProvider", "EastmoneyNewsFetcher"] ================================================ FILE: backend/app/financial/providers/eastmoney/fetchers/__init__.py ================================================ """ 东方财富 Fetchers """ from .news import EastmoneyNewsFetcher __all__ = ["EastmoneyNewsFetcher"] ================================================ FILE: backend/app/financial/providers/eastmoney/fetchers/news.py ================================================ """ 东方财富新闻 Fetcher 基于 TET Pipeline 实现 """ import re import logging from typing import List, Dict, Any, Optional from datetime import datetime from bs4 import BeautifulSoup import requests from ...base import BaseFetcher from ....models.news import NewsQueryParams, NewsData, NewsSentiment logger = logging.getLogger(__name__) class EastmoneyNewsFetcher(BaseFetcher): """ 东方财富新闻 Fetcher 数据源: https://stock.eastmoney.com/ """ BASE_URL = "https://stock.eastmoney.com/" STOCK_URL = "https://stock.eastmoney.com/news/" SOURCE_NAME = "eastmoney" HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", } def transform_query(self, params: NewsQueryParams) -> Dict[str, Any]: """转换标准查询参数""" return { "url": self.STOCK_URL, "limit": params.limit or 20, "stock_codes": params.stock_codes, "keywords": params.keywords, } def extract_data(self, query: Dict[str, Any]) -> List[Dict[str, Any]]: """从东方财富抓取原始数据""" raw_news = [] try: # 尝试股票新闻页面,失败则尝试主页 try: response = requests.get(query["url"], headers=self.HEADERS, timeout=30) response.raise_for_status() except: response = requests.get(self.BASE_URL, headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") news_links = self._extract_news_links(soup) logger.info(f"[Eastmoney] Found {len(news_links)} news links") max_fetch = min(query["limit"], 20) for link_info in news_links[:max_fetch]: try: news_item = 
self._fetch_news_detail(link_info) if news_item: raw_news.append(news_item) except Exception as e: logger.warning(f"[Eastmoney] Failed to fetch {link_info['url']}: {e}") continue logger.info(f"[Eastmoney] Extracted {len(raw_news)} news items") except Exception as e: logger.error(f"[Eastmoney] Extract failed: {e}") return raw_news def transform_data( self, raw_data: List[Dict[str, Any]], params: NewsQueryParams ) -> List[NewsData]: """转换原始数据为标准 NewsData 格式""" news_list = [] for item in raw_data: try: stock_codes = self._extract_stock_codes( item.get("title", "") + " " + item.get("content", "") ) if params.stock_codes: if not any(code in stock_codes for code in params.stock_codes): continue if params.keywords: text = item.get("title", "") + " " + item.get("content", "") if not any(kw in text for kw in params.keywords): continue news = NewsData( title=item.get("title", ""), content=item.get("content", ""), source=self.SOURCE_NAME, source_url=item.get("url", ""), publish_time=item.get("publish_time", datetime.now()), author=item.get("author"), stock_codes=stock_codes, sentiment=NewsSentiment.NEUTRAL, ) news_list.append(news) except Exception as e: logger.warning(f"[Eastmoney] Transform failed: {e}") continue if params.limit: news_list = news_list[:params.limit] return news_list def _extract_news_links(self, soup: BeautifulSoup) -> List[Dict[str, str]]: """从页面提取新闻链接""" news_links = [] all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 东方财富新闻URL模式 if ('eastmoney.com' in href and ('/news/' in href or '/stock/' in href or '.html' in href)) and title: if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://stock.eastmoney.com' + href elif not href.startswith('http'): href = 'https://stock.eastmoney.com/' + href.lstrip('/') if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _fetch_news_detail(self, link_info: Dict[str, str]) -> Optional[Dict[str, Any]]: """获取新闻详情""" url = link_info['url'] title = link_info['title'] try: response = requests.get(url, headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") content = self._extract_content(soup) if not content: return None publish_time = self._extract_publish_time(soup) author = self._extract_author(soup) return { "title": title, "content": content, "url": url, "publish_time": publish_time, "author": author, } except Exception as e: logger.debug(f"[Eastmoney] Detail fetch failed: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'Body'}, {'id': 'ContentBody'}, {'class': 'article-content'}, {'class': 'newsContent'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([ p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True) ]) if content: return self._clean_text(content) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> datetime: """提取发布时间""" try: time_elem = soup.find('div', {'class': re.compile(r'time|date')}) if not time_elem: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception: pass return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" formats 
= ['%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M'] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return datetime.now() def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: elem = soup.find('div', {'class': re.compile(r'author|source')}) if not elem: elem = soup.find('span', {'class': re.compile(r'author|source')}) if elem: return elem.get_text(strip=True) except Exception: pass return None def _extract_stock_codes(self, text: str) -> List[str]: """从文本提取股票代码""" patterns = [ r'(\d{6})\.(SH|SZ|sh|sz)', r'(SH|SZ|sh|sz)(\d{6})', r'[((](\d{6})[))]', ] codes = set() for pattern in patterns: matches = re.findall(pattern, text) for match in matches: if isinstance(match, tuple): code = ''.join(match) else: code = match code = re.sub(r'[^0-9]', '', code) if len(code) == 6: codes.add(code) return list(codes) def _clean_text(self, text: str) -> str: """清理文本""" text = re.sub(r'\s+', ' ', text) return text.strip() ================================================ FILE: backend/app/financial/providers/eastmoney/provider.py ================================================ """ 东方财富 Provider """ from typing import Dict, Type from ..base import BaseProvider, BaseFetcher, ProviderInfo from .fetchers.news import EastmoneyNewsFetcher class EastmoneyProvider(BaseProvider): """ 东方财富数据源 支持的数据类型: - news: 财经新闻 """ @property def info(self) -> ProviderInfo: return ProviderInfo( name="eastmoney", display_name="东方财富", description="东方财富股票新闻 (eastmoney.com)", website="https://stock.eastmoney.com/", requires_credentials=False, priority=4 # 第四优先级 ) @property def fetchers(self) -> Dict[str, Type[BaseFetcher]]: return { "news": EastmoneyNewsFetcher, } ================================================ FILE: backend/app/financial/providers/nbd/__init__.py ================================================ """ 每日经济新闻 Provider """ from .provider import NbdProvider from .fetchers.news import NbdNewsFetcher __all__ = ["NbdProvider", "NbdNewsFetcher"] ================================================ FILE: backend/app/financial/providers/nbd/fetchers/__init__.py ================================================ """ 每日经济新闻 Fetchers """ from .news import NbdNewsFetcher __all__ = ["NbdNewsFetcher"] ================================================ FILE: backend/app/financial/providers/nbd/fetchers/news.py ================================================ """ 每日经济新闻 Fetcher 基于 TET Pipeline 实现 """ import re import logging from typing import List, Dict, Any, Optional from datetime import datetime from bs4 import BeautifulSoup import requests from ...base import BaseFetcher from ....models.news import NewsQueryParams, NewsData, NewsSentiment logger = logging.getLogger(__name__) class NbdNewsFetcher(BaseFetcher): """ 每日经济新闻 Fetcher 数据源: https://www.nbd.com.cn/ """ BASE_URL = "https://www.nbd.com.cn/" STOCK_URL = "https://www.nbd.com.cn/columns/3/" # 股市栏目 SOURCE_NAME = "nbd" HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", } def transform_query(self, params: NewsQueryParams) -> Dict[str, Any]: """转换标准查询参数""" return { "url": self.STOCK_URL, "limit": params.limit or 20, "stock_codes": params.stock_codes, "keywords": params.keywords, } def extract_data(self, query: Dict[str, Any]) -> List[Dict[str, Any]]: """从每日经济新闻抓取原始数据""" raw_news = [] try: 
response = requests.get(query["url"], headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") news_links = self._extract_news_links(soup) logger.info(f"[NBD] Found {len(news_links)} news links") max_fetch = min(query["limit"], 20) for link_info in news_links[:max_fetch]: try: news_item = self._fetch_news_detail(link_info) if news_item: raw_news.append(news_item) except Exception as e: logger.warning(f"[NBD] Failed to fetch {link_info['url']}: {e}") continue logger.info(f"[NBD] Extracted {len(raw_news)} news items") except Exception as e: logger.error(f"[NBD] Extract failed: {e}") return raw_news def transform_data( self, raw_data: List[Dict[str, Any]], params: NewsQueryParams ) -> List[NewsData]: """转换原始数据为标准 NewsData 格式""" news_list = [] for item in raw_data: try: stock_codes = self._extract_stock_codes( item.get("title", "") + " " + item.get("content", "") ) if params.stock_codes: if not any(code in stock_codes for code in params.stock_codes): continue if params.keywords: text = item.get("title", "") + " " + item.get("content", "") if not any(kw in text for kw in params.keywords): continue news = NewsData( title=item.get("title", ""), content=item.get("content", ""), source=self.SOURCE_NAME, source_url=item.get("url", ""), publish_time=item.get("publish_time", datetime.now()), author=item.get("author"), stock_codes=stock_codes, sentiment=NewsSentiment.NEUTRAL, ) news_list.append(news) except Exception as e: logger.warning(f"[NBD] Transform failed: {e}") continue if params.limit: news_list = news_list[:params.limit] return news_list def _extract_news_links(self, soup: BeautifulSoup) -> List[Dict[str, str]]: """从页面提取新闻链接""" news_links = [] all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) if ('/articles/' in href or '/article/' in href or '.html' in href) and title: if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://www.nbd.com.cn' + href elif not href.startswith('http'): href = 'https://www.nbd.com.cn/' + href.lstrip('/') if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _fetch_news_detail(self, link_info: Dict[str, str]) -> Optional[Dict[str, Any]]: """获取新闻详情""" url = link_info['url'] title = link_info['title'] try: response = requests.get(url, headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") content = self._extract_content(soup) if not content: return None publish_time = self._extract_publish_time(soup) author = self._extract_author(soup) return { "title": title, "content": content, "url": url, "publish_time": publish_time, "author": author, } except Exception as e: logger.debug(f"[NBD] Detail fetch failed: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'article-body'}, {'class': 'article__body'}, {'class': 'article-text'}, {'class': 'content-article'}, {'class': 'main-content'}, {'class': 'g-article-content'}, {'class': 'article-content'}, {'id': 'contentText'}, ] for selector in content_selectors: content_div = soup.find(['div', 'article', 'section'], selector) if content_div: for tag in content_div.find_all(['script', 'style', 'iframe', 'ins']): tag.decompose() for ad in content_div.find_all(class_=re.compile(r'ad|advertisement|banner')): ad.decompose() paragraphs = content_div.find_all('p') if paragraphs: 
content = '\n'.join([ p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True) ]) if content and len(content) > 50: return self._clean_text(content) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> datetime: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date|pub')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception: pass return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" formats = ['%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M'] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return datetime.now() def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: elem = soup.find('span', {'class': re.compile(r'author|source|editor')}) if elem: return elem.get_text(strip=True) except Exception: pass return None def _extract_stock_codes(self, text: str) -> List[str]: """从文本提取股票代码""" patterns = [ r'(\d{6})\.(SH|SZ|sh|sz)', r'(SH|SZ|sh|sz)(\d{6})', r'[((](\d{6})[))]', ] codes = set() for pattern in patterns: matches = re.findall(pattern, text) for match in matches: if isinstance(match, tuple): code = ''.join(match) else: code = match code = re.sub(r'[^0-9]', '', code) if len(code) == 6: codes.add(code) return list(codes) def _clean_text(self, text: str) -> str: """清理文本""" text = re.sub(r'\s+', ' ', text) return text.strip() ================================================ FILE: backend/app/financial/providers/nbd/provider.py ================================================ """ 每日经济新闻 Provider """ from typing import Dict, Type from ..base import BaseProvider, BaseFetcher, ProviderInfo from .fetchers.news import NbdNewsFetcher class NbdProvider(BaseProvider): """ 每日经济新闻数据源 支持的数据类型: - news: 财经新闻 """ @property def info(self) -> ProviderInfo: return ProviderInfo( name="nbd", display_name="每日经济新闻", description="每日经济新闻 (nbd.com.cn)", website="https://www.nbd.com.cn/", requires_credentials=False, priority=3 # 第三优先级 ) @property def fetchers(self) -> Dict[str, Type[BaseFetcher]]: return { "news": NbdNewsFetcher, } ================================================ FILE: backend/app/financial/providers/netease/__init__.py ================================================ """ 网易财经 Provider """ from .provider import NeteaseProvider from .fetchers.news import NeteaseNewsFetcher __all__ = ["NeteaseProvider", "NeteaseNewsFetcher"] ================================================ FILE: backend/app/financial/providers/netease/fetchers/__init__.py ================================================ """ 网易财经 Fetchers """ from .news import NeteaseNewsFetcher __all__ = ["NeteaseNewsFetcher"] ================================================ FILE: backend/app/financial/providers/netease/fetchers/news.py ================================================ """ 网易财经新闻 Fetcher 基于 TET Pipeline 实现 """ import re import logging from typing import List, Dict, Any, Optional from datetime import datetime from bs4 import BeautifulSoup import requests from ...base import BaseFetcher from ....models.news import NewsQueryParams, NewsData, NewsSentiment logger = logging.getLogger(__name__) class NeteaseNewsFetcher(BaseFetcher): """ 网易财经新闻 Fetcher 数据源: https://money.163.com/ """ BASE_URL = "https://money.163.com/" STOCK_URL = "https://money.163.com/stock/" SOURCE_NAME = "163" HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 
Safari/537.36", "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", } def transform_query(self, params: NewsQueryParams) -> Dict[str, Any]: """转换标准查询参数""" return { "url": self.STOCK_URL, "limit": params.limit or 20, "stock_codes": params.stock_codes, "keywords": params.keywords, } def extract_data(self, query: Dict[str, Any]) -> List[Dict[str, Any]]: """从网易财经抓取原始数据""" raw_news = [] try: # 尝试股票页面,失败则尝试主页 try: response = requests.get(query["url"], headers=self.HEADERS, timeout=30) response.raise_for_status() except: response = requests.get(self.BASE_URL, headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") news_links = self._extract_news_links(soup) logger.info(f"[Netease] Found {len(news_links)} news links") max_fetch = min(query["limit"], 20) for link_info in news_links[:max_fetch]: try: news_item = self._fetch_news_detail(link_info) if news_item: raw_news.append(news_item) except Exception as e: logger.warning(f"[Netease] Failed to fetch {link_info['url']}: {e}") continue logger.info(f"[Netease] Extracted {len(raw_news)} news items") except Exception as e: logger.error(f"[Netease] Extract failed: {e}") return raw_news def transform_data( self, raw_data: List[Dict[str, Any]], params: NewsQueryParams ) -> List[NewsData]: """转换原始数据为标准 NewsData 格式""" news_list = [] for item in raw_data: try: stock_codes = self._extract_stock_codes( item.get("title", "") + " " + item.get("content", "") ) if params.stock_codes: if not any(code in stock_codes for code in params.stock_codes): continue if params.keywords: text = item.get("title", "") + " " + item.get("content", "") if not any(kw in text for kw in params.keywords): continue news = NewsData( title=item.get("title", ""), content=item.get("content", ""), source=self.SOURCE_NAME, source_url=item.get("url", ""), publish_time=item.get("publish_time", datetime.now()), author=item.get("author"), stock_codes=stock_codes, sentiment=NewsSentiment.NEUTRAL, ) news_list.append(news) except Exception as e: logger.warning(f"[Netease] Transform failed: {e}") continue if params.limit: news_list = news_list[:params.limit] return news_list def _extract_news_links(self, soup: BeautifulSoup) -> List[Dict[str, str]]: """从页面提取新闻链接""" news_links = [] all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 网易新闻URL模式 if ('money.163.com' in href or 'stock' in href) and title: if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://money.163.com' + href elif not href.startswith('http'): href = 'https://money.163.com/' + href.lstrip('/') if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _fetch_news_detail(self, link_info: Dict[str, str]) -> Optional[Dict[str, Any]]: """获取新闻详情""" url = link_info['url'] title = link_info['title'] try: response = requests.get(url, headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") content = self._extract_content(soup) if not content: return None publish_time = self._extract_publish_time(soup) author = self._extract_author(soup) return { "title": title, "content": content, "url": url, "publish_time": publish_time, "author": author, } except Exception as e: logger.debug(f"[Netease] Detail fetch failed: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: 
"""提取新闻正文""" content_selectors = [ {'class': 'post_text'}, {'id': 'endText'}, {'class': 'article-content'}, {'class': 'content'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([ p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True) ]) if content: return self._clean_text(content) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> datetime: """提取发布时间""" try: time_elem = soup.find('div', {'class': re.compile(r'post_time|time')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception: pass return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" formats = ['%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M'] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return datetime.now() def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: elem = soup.find('span', {'class': re.compile(r'author|source')}) if not elem: elem = soup.find('div', {'id': 'ne_article_source'}) if elem: return elem.get_text(strip=True) except Exception: pass return None def _extract_stock_codes(self, text: str) -> List[str]: """从文本提取股票代码""" patterns = [ r'(\d{6})\.(SH|SZ|sh|sz)', r'(SH|SZ|sh|sz)(\d{6})', r'[((](\d{6})[))]', ] codes = set() for pattern in patterns: matches = re.findall(pattern, text) for match in matches: if isinstance(match, tuple): code = ''.join(match) else: code = match code = re.sub(r'[^0-9]', '', code) if len(code) == 6: codes.add(code) return list(codes) def _clean_text(self, text: str) -> str: """清理文本""" text = re.sub(r'\s+', ' ', text) return text.strip() ================================================ FILE: backend/app/financial/providers/netease/provider.py ================================================ """ 网易财经 Provider """ from typing import Dict, Type from ..base import BaseProvider, BaseFetcher, ProviderInfo from .fetchers.news import NeteaseNewsFetcher class NeteaseProvider(BaseProvider): """ 网易财经数据源 支持的数据类型: - news: 财经新闻 """ @property def info(self) -> ProviderInfo: return ProviderInfo( name="163", display_name="网易财经", description="网易财经股票新闻 (money.163.com)", website="https://money.163.com/", requires_credentials=False, priority=6 # 第六优先级 ) @property def fetchers(self) -> Dict[str, Type[BaseFetcher]]: return { "news": NeteaseNewsFetcher, } ================================================ FILE: backend/app/financial/providers/sina/__init__.py ================================================ """ 新浪财经 Provider 提供: - 新闻数据 (news): SinaNewsFetcher 从 tools/sina_crawler.py 迁移而来,保留核心逻辑, 适配 TET Pipeline 架构。 """ from .provider import SinaProvider __all__ = ["SinaProvider"] ================================================ FILE: backend/app/financial/providers/sina/fetchers/__init__.py ================================================ """ 新浪财经 Fetchers """ from .news import SinaNewsFetcher __all__ = ["SinaNewsFetcher"] ================================================ FILE: backend/app/financial/providers/sina/fetchers/news.py ================================================ """ 新浪财经新闻 Fetcher 从 tools/sina_crawler.py 迁移而来,适配 TET Pipeline 架构。 主要变更: - transform_query: 将 NewsQueryParams 转换为爬虫参数 - extract_data: 执行网页爬取 - transform_data: 将原始数据转换为 NewsData 标准模型 保留原有的: - 网页解析逻辑 - 标题/内容/日期提取 - 股票代码提取 - 噪音过滤 来源: tools/sina_crawler.py (SinaCrawlerTool) """ import re import time import 
hashlib import logging from typing import Dict, Any, List, Optional from datetime import datetime from bs4 import BeautifulSoup from ...base import BaseFetcher from ....models.news import NewsQueryParams, NewsData logger = logging.getLogger(__name__) class SinaNewsFetcher(BaseFetcher[NewsQueryParams, NewsData]): """ 新浪财经新闻获取器 实现 TET Pipeline: - Transform Query: 将 NewsQueryParams 转换为爬虫参数 - Extract Data: 爬取网页 - Transform Data: 解析为 NewsData """ query_model = NewsQueryParams data_model = NewsData # 新浪财经最新滚动新闻页面 BASE_URL = "https://finance.sina.com.cn/roll/c/56592.shtml" SOURCE_NAME = "sina" # 请求配置 DEFAULT_TIMEOUT = 30 DEFAULT_DELAY = 0.5 DEFAULT_USER_AGENT = ( "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " "AppleWebKit/537.36 (KHTML, like Gecko) " "Chrome/120.0.0.0 Safari/537.36" ) # 噪音文本模式 NOISE_PATTERNS = [ r'^责任编辑', r'^编辑[::]', r'^来源[::]', r'^声明[::]', r'^免责声明', r'^版权', r'^copyright', r'^点击进入', r'^相关阅读', r'^延伸阅读', r'登录新浪财经APP', r'搜索【信披】', r'缩小字体', r'放大字体', r'收藏', r'微博', r'微信', r'分享', r'腾讯QQ', ] def __init__(self): super().__init__() self._session = None def _get_session(self): """获取 requests Session (延迟初始化)""" if self._session is None: import requests self._session = requests.Session() self._session.headers.update({ 'User-Agent': self.DEFAULT_USER_AGENT }) return self._session def transform_query(self, params: NewsQueryParams) -> Dict[str, Any]: """ 将标准参数转换为爬虫参数 Args: params: 标准查询参数 Returns: 爬虫参数字典 """ query = { "base_url": self.BASE_URL, "limit": params.limit, "stock_codes": params.stock_codes or [], "keywords": params.keywords or [], } # 如果有股票代码,构建股票新闻 URL if params.stock_codes: query["stock_urls"] = [] for code in params.stock_codes: symbol = self._normalize_symbol(code) stock_url = ( f"https://vip.stock.finance.sina.com.cn" f"/corp/go.php/vCB_AllNewsStock/symbol/{symbol}.phtml" ) query["stock_urls"].append(stock_url) return query async def extract_data(self, query: Dict[str, Any]) -> List[Dict]: """ 执行网页爬取 Args: query: transform_query 返回的参数 Returns: 原始新闻数据列表 """ all_news = [] limit = query["limit"] # 确定要爬取的 URL 列表 urls_to_crawl = query.get("stock_urls", [query["base_url"]]) if not urls_to_crawl: urls_to_crawl = [query["base_url"]] for url in urls_to_crawl: try: news_items = await self._crawl_page(url, limit - len(all_news)) all_news.extend(news_items) if len(all_news) >= limit: break except Exception as e: self.logger.error(f"Failed to crawl {url}: {e}") continue return all_news[:limit] async def _crawl_page(self, url: str, max_items: int) -> List[Dict]: """爬取单个页面""" import asyncio self.logger.info(f"Fetching page: {url}") # 使用 run_in_executor 执行同步请求 loop = asyncio.get_event_loop() response = await loop.run_in_executor( None, lambda: self._fetch_page_sync(url) ) if not response: return [] # 设置编码 response.encoding = 'utf-8' soup = BeautifulSoup(response.text, 'lxml') # 查找新闻链接 news_links = self._extract_news_links(soup) self.logger.info(f"Found {len(news_links)} news links") # 爬取每条新闻详情 news_list = [] for idx, news_url in enumerate(news_links[:max_items], 1): try: self.logger.debug(f"Crawling news {idx}/{min(len(news_links), max_items)}") news_item = await self._crawl_news_detail(news_url) if news_item: news_list.append(news_item) except Exception as e: self.logger.warning(f"Failed to crawl {news_url}: {e}") continue # 请求间隔 await asyncio.sleep(self.DEFAULT_DELAY) return news_list def _fetch_page_sync(self, url: str): """同步获取页面""" try: session = self._get_session() response = session.get(url, timeout=self.DEFAULT_TIMEOUT) response.raise_for_status() return response except Exception as e: 
self.logger.error(f"Failed to fetch {url}: {e}") return None def _extract_news_links(self, soup: BeautifulSoup) -> List[str]: """提取新闻链接""" news_links = [] for link in soup.find_all('a', href=True): href = link.get('href', '') # 匹配新浪财经新闻 URL if 'finance.sina.com.cn' in href and ('/stock/' in href or '/roll/' in href): if href.startswith('http'): news_links.append(href) elif href.startswith('//'): news_links.append('http:' + href) # 去重 return list(set(news_links)) async def _crawl_news_detail(self, url: str) -> Optional[Dict]: """爬取新闻详情""" import asyncio loop = asyncio.get_event_loop() response = await loop.run_in_executor( None, lambda: self._fetch_page_sync(url) ) if not response: return None try: soup = BeautifulSoup(response.content, "lxml") raw_html = response.text # 提取各字段 title = self._extract_title(soup) if not title: return None summary, keywords = self._extract_meta(soup) publish_time = self._extract_date(soup) stock_codes = self._extract_stock_codes(soup) content = self._extract_content(soup) if not content or len(content) < 50: return None return { "url": url, "title": title, "content": content, "summary": summary, "keywords": keywords, "publish_time": publish_time, "stock_codes": stock_codes, "raw_html": raw_html, } except Exception as e: self.logger.error(f"Error parsing {url}: {e}") return None def transform_data( self, raw_data: List[Dict], query: NewsQueryParams ) -> List[NewsData]: """ 将原始数据转换为 NewsData 标准模型 Args: raw_data: extract_data 返回的原始数据 query: 原始查询参数 Returns: NewsData 列表 """ results = [] for item in raw_data: try: news = NewsData( id=NewsData.generate_id(item["url"]), title=item["title"], content=item["content"], summary=item.get("summary"), source=self.SOURCE_NAME, source_url=item["url"], publish_time=item.get("publish_time") or datetime.now(), stock_codes=item.get("stock_codes", []), keywords=item.get("keywords", []), extra={"raw_html": item.get("raw_html")}, ) results.append(news) except Exception as e: self.logger.warning(f"Failed to transform item: {e}") continue return results # ========== 辅助方法(从原 sina_crawler.py 迁移)========== def _normalize_symbol(self, code: str) -> str: """标准化股票代码为新浪格式""" code = code.upper().replace("SH", "sh").replace("SZ", "sz") if code.isdigit(): if code.startswith("6"): return f"sh{code}" else: return f"sz{code}" return code.lower() def _extract_title(self, soup: BeautifulSoup) -> Optional[str]: """提取标题""" title_tag = soup.find('h1', class_='main-title') if not title_tag: title_tag = soup.find('h1') if not title_tag: title_tag = soup.find('title') if title_tag: title = title_tag.get_text().strip() title = re.sub(r'[-_].*?(新浪|财经|网)', '', title) return title.strip() return None def _extract_meta(self, soup: BeautifulSoup) -> tuple: """提取元数据(摘要和关键词)""" summary = "" keywords = [] for meta in soup.find_all('meta'): name = meta.get('name', '').lower() content = meta.get('content', '') if name == 'description': summary = content elif name == 'keywords': keywords = [kw.strip() for kw in content.split(',') if kw.strip()] return summary, keywords def _extract_date(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" for span in soup.find_all('span'): class_attr = span.get('class', []) if 'date' in class_attr or 'time-source' in class_attr: date_text = span.get_text() return self._parse_date(date_text) if span.get('id') == 'pub_date': date_text = span.get_text() return self._parse_date(date_text) return None def _parse_date(self, date_text: str) -> Optional[datetime]: """解析日期字符串""" try: date_text = date_text.strip() date_text = 
date_text.replace('年', '-').replace('月', '-').replace('日', '') for fmt in ['%Y-%m-%d %H:%M', '%Y-%m-%d %H:%M:%S', '%Y-%m-%d']: try: return datetime.strptime(date_text.strip(), fmt) except ValueError: continue except Exception: pass return None def _extract_stock_codes(self, soup: BeautifulSoup) -> List[str]: """提取关联股票代码""" stock_codes = [] for span in soup.find_all('span'): span_id = span.get('id', '') if span_id.startswith('stock_'): code = span_id[6:].upper() if code: stock_codes.append(code) return list(set(stock_codes)) def _extract_content(self, soup: BeautifulSoup) -> str: """提取正文内容""" content_selectors = [ {'id': 'artibody'}, {'class': 'article-content'}, {'class': 'article'}, {'id': 'article'}, ] for selector in content_selectors: content_div = soup.find(['div', 'article'], selector) if content_div: # 移除噪音元素 for tag in content_div.find_all([ 'script', 'style', 'iframe', 'ins', 'select', 'input', 'button', 'form' ]): tag.decompose() for ad in content_div.find_all(class_=re.compile( r'ad|banner|share|otherContent|recommend|app-guide', re.I )): ad.decompose() # 提取文本 full_text = content_div.get_text(separator='\n', strip=True) lines = full_text.split('\n') article_parts = [] for line in lines: line = line.strip() if not line or len(line) < 2: continue if not self._is_noise_text(line): article_parts.append(line) if article_parts: return '\n'.join(article_parts) return "" def _is_noise_text(self, text: str) -> bool: """判断是否为噪音文本""" text_lower = text.lower().strip() for pattern in self.NOISE_PATTERNS: if re.match(pattern, text_lower, re.I) or re.search(pattern, text_lower, re.I): return True return False def _extract_chinese_ratio(self, text: str) -> float: """计算中文字符比例""" pattern = re.compile(r'[\u4e00-\u9fa5]+') chinese_chars = pattern.findall(text) chinese_count = sum(len(chars) for chars in chinese_chars) total_count = len(text) return chinese_count / total_count if total_count > 0 else 0 ================================================ FILE: backend/app/financial/providers/sina/provider.py ================================================ """ 新浪财经 Provider """ from typing import Dict, Type from ..base import BaseProvider, BaseFetcher, ProviderInfo from .fetchers.news import SinaNewsFetcher class SinaProvider(BaseProvider): """ 新浪财经数据源 支持的数据类型: - news: 财经新闻 """ @property def info(self) -> ProviderInfo: return ProviderInfo( name="sina", display_name="新浪财经", description="新浪财经新闻和股票数据", website="https://finance.sina.com.cn", requires_credentials=False, priority=1 # 第一优先级 ) @property def fetchers(self) -> Dict[str, Type[BaseFetcher]]: return { "news": SinaNewsFetcher, # 可扩展: "stock_price": SinaStockFetcher } ================================================ FILE: backend/app/financial/providers/tencent/__init__.py ================================================ """ 腾讯财经 Provider """ from .provider import TencentProvider from .fetchers.news import TencentNewsFetcher __all__ = ["TencentProvider", "TencentNewsFetcher"] ================================================ FILE: backend/app/financial/providers/tencent/fetchers/__init__.py ================================================ """ 腾讯财经 Fetchers """ from .news import TencentNewsFetcher __all__ = ["TencentNewsFetcher"] ================================================ FILE: backend/app/financial/providers/tencent/fetchers/news.py ================================================ """ 腾讯财经新闻 Fetcher 基于 TET Pipeline 实现: - Transform Query: 转换标准参数为腾讯财经特定参数 - Extract Data: 从腾讯财经抓取原始数据 - Transform Data: 转换为标准 NewsData 格式 """ import re import logging 
from typing import List, Dict, Any, Optional from datetime import datetime, timedelta from bs4 import BeautifulSoup import requests from ...base import BaseFetcher from ....models.news import NewsQueryParams, NewsData, NewsSentiment logger = logging.getLogger(__name__) class TencentNewsFetcher(BaseFetcher): """ 腾讯财经新闻 Fetcher 数据源: https://news.qq.com/ch/finance/ """ BASE_URL = "https://news.qq.com/ch/finance/" SOURCE_NAME = "tencent" # 请求配置 HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", } def transform_query(self, params: NewsQueryParams) -> Dict[str, Any]: """ 转换标准查询参数为腾讯财经特定参数 """ return { "url": self.BASE_URL, "limit": params.limit or 20, "stock_codes": params.stock_codes, "keywords": params.keywords, } def extract_data(self, query: Dict[str, Any]) -> List[Dict[str, Any]]: """ 从腾讯财经抓取原始新闻数据 """ raw_news = [] try: response = requests.get( query["url"], headers=self.HEADERS, timeout=30 ) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") news_links = self._extract_news_links(soup) logger.info(f"[Tencent] Found {len(news_links)} news links") # 限制获取数量 max_fetch = min(query["limit"], 20) for link_info in news_links[:max_fetch]: try: news_item = self._fetch_news_detail(link_info) if news_item: raw_news.append(news_item) except Exception as e: logger.warning(f"[Tencent] Failed to fetch {link_info['url']}: {e}") continue logger.info(f"[Tencent] Extracted {len(raw_news)} news items") except Exception as e: logger.error(f"[Tencent] Extract failed: {e}") return raw_news def transform_data( self, raw_data: List[Dict[str, Any]], params: NewsQueryParams ) -> List[NewsData]: """ 转换原始数据为标准 NewsData 格式 """ news_list = [] for item in raw_data: try: # 提取股票代码 stock_codes = self._extract_stock_codes( item.get("title", "") + " " + item.get("content", "") ) # 过滤:如果指定了股票代码,只保留相关新闻 if params.stock_codes: if not any(code in stock_codes for code in params.stock_codes): continue # 过滤:关键词过滤 if params.keywords: text = item.get("title", "") + " " + item.get("content", "") if not any(kw in text for kw in params.keywords): continue news = NewsData( title=item.get("title", ""), content=item.get("content", ""), source=self.SOURCE_NAME, source_url=item.get("url", ""), publish_time=item.get("publish_time", datetime.now()), author=item.get("author"), stock_codes=stock_codes, sentiment=NewsSentiment.NEUTRAL, # 默认中性 ) news_list.append(news) except Exception as e: logger.warning(f"[Tencent] Transform failed for item: {e}") continue # 应用 limit if params.limit: news_list = news_list[:params.limit] return news_list def _extract_news_links(self, soup: BeautifulSoup) -> List[Dict[str, str]]: """从页面提取新闻链接""" news_links = [] all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') # 腾讯新闻URL模式 if '/rain/a/' in href or '/omn/' in href: if not href.startswith('http'): href = 'https:' + href if href.startswith('//') else 'https://news.qq.com' + href title = link.get_text(strip=True) if title and href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _fetch_news_detail(self, link_info: Dict[str, str]) -> Optional[Dict[str, Any]]: """获取新闻详情""" url = link_info['url'] title = link_info['title'] try: response = requests.get(url, headers=self.HEADERS, timeout=30) response.raise_for_status() soup = 
BeautifulSoup(response.text, "html.parser") content = self._extract_content(soup) if not content: return None publish_time = self._extract_publish_time(soup) author = self._extract_author(soup) return { "title": title, "content": content, "url": url, "publish_time": publish_time, "author": author, } except Exception as e: logger.debug(f"[Tencent] Detail fetch failed: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'content-article'}, {'class': 'LEFT'}, {'id': 'Cnt-Main-Article-QQ'}, {'class': 'article'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([ p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True) ]) if content: return self._clean_text(content) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> datetime: """提取发布时间""" try: time_selectors = [ {'class': 'a-time'}, {'class': 'article-time'}, {'class': 'time'}, ] for selector in time_selectors: time_elem = soup.find('span', selector) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) meta_time = soup.find('meta', {'property': 'article:published_time'}) if meta_time and meta_time.get('content'): return datetime.fromisoformat(meta_time['content'].replace('Z', '+00:00')) except Exception: pass return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() if '分钟前' in time_str: minutes = int(re.search(r'(\d+)', time_str).group(1)) return now - timedelta(minutes=minutes) elif '小时前' in time_str: hours = int(re.search(r'(\d+)', time_str).group(1)) return now - timedelta(hours=hours) elif '昨天' in time_str: return now - timedelta(days=1) formats = ['%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d'] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: for selector in [{'class': 'author'}, {'class': 'source'}]: elem = soup.find('span', selector) or soup.find('a', selector) if elem: return elem.get_text(strip=True) except Exception: pass return None def _extract_stock_codes(self, text: str) -> List[str]: """从文本提取股票代码""" patterns = [ r'(\d{6})\.(SH|SZ|sh|sz)', r'(SH|SZ|sh|sz)(\d{6})', r'[((](\d{6})[))]', ] codes = set() for pattern in patterns: matches = re.findall(pattern, text) for match in matches: if isinstance(match, tuple): code = ''.join(match) else: code = match code = re.sub(r'[^0-9]', '', code) if len(code) == 6: codes.add(code) return list(codes) def _clean_text(self, text: str) -> str: """清理文本""" text = re.sub(r'\s+', ' ', text) text = text.strip() return text ================================================ FILE: backend/app/financial/providers/tencent/provider.py ================================================ """ 腾讯财经 Provider """ from typing import Dict, Type from ..base import BaseProvider, BaseFetcher, ProviderInfo from .fetchers.news import TencentNewsFetcher class TencentProvider(BaseProvider): """ 腾讯财经数据源 支持的数据类型: - news: 财经新闻 """ @property def info(self) -> ProviderInfo: return ProviderInfo( name="tencent", display_name="腾讯财经", description="腾讯财经新闻 (news.qq.com)", website="https://news.qq.com/ch/finance/", requires_credentials=False, priority=2 # 第二优先级 ) @property def fetchers(self) -> Dict[str, Type[BaseFetcher]]: return { "news": TencentNewsFetcher, } 
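A minimal end-to-end sketch (editor's example, not part of the original repository) of driving one of the providers above through the TET pipeline defined in providers/base.py. SinaNewsFetcher is used here because it declares extract_data as `async def`, which is what BaseFetcher.fetch awaits; the other news fetchers above implement extract_data as a plain synchronous method. The `app.financial...` import root is assumed from the package layout, and live network access to finance.sina.com.cn is required.

import asyncio

from app.financial.models.news import NewsQueryParams   # assumed import root
from app.financial.providers.sina import SinaProvider

provider = SinaProvider()
fetcher = provider.get_fetcher("news")        # instantiates SinaNewsFetcher
params = NewsQueryParams(stock_codes=["600519"], limit=5)

# fetch() chains transform_query -> extract_data -> transform_data and
# returns a list of standardized NewsData records.
news_items = asyncio.run(fetcher.fetch(params))
for item in news_items:
    print(item.source, item.publish_time.isoformat(), item.title)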
================================================ FILE: backend/app/financial/providers/yicai/__init__.py ================================================ """ 第一财经 Provider """ from .provider import YicaiProvider from .fetchers.news import YicaiNewsFetcher __all__ = ["YicaiProvider", "YicaiNewsFetcher"] ================================================ FILE: backend/app/financial/providers/yicai/fetchers/__init__.py ================================================ """ 第一财经 Fetchers """ from .news import YicaiNewsFetcher __all__ = ["YicaiNewsFetcher"] ================================================ FILE: backend/app/financial/providers/yicai/fetchers/news.py ================================================ """ 第一财经新闻 Fetcher 基于 TET Pipeline 实现 """ import re import logging from typing import List, Dict, Any, Optional from datetime import datetime from bs4 import BeautifulSoup import requests from ...base import BaseFetcher from ....models.news import NewsQueryParams, NewsData, NewsSentiment logger = logging.getLogger(__name__) class YicaiNewsFetcher(BaseFetcher): """ 第一财经新闻 Fetcher 数据源: https://www.yicai.com/ """ BASE_URL = "https://www.yicai.com/" STOCK_URL = "https://www.yicai.com/news/gushi/" SOURCE_NAME = "yicai" HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", } def transform_query(self, params: NewsQueryParams) -> Dict[str, Any]: """转换标准查询参数""" return { "url": self.STOCK_URL, "limit": params.limit or 20, "stock_codes": params.stock_codes, "keywords": params.keywords, } def extract_data(self, query: Dict[str, Any]) -> List[Dict[str, Any]]: """从第一财经抓取原始数据""" raw_news = [] try: response = requests.get(query["url"], headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") news_links = self._extract_news_links(soup) logger.info(f"[Yicai] Found {len(news_links)} news links") max_fetch = min(query["limit"], 20) for link_info in news_links[:max_fetch]: try: news_item = self._fetch_news_detail(link_info) if news_item: raw_news.append(news_item) except Exception as e: logger.warning(f"[Yicai] Failed to fetch {link_info['url']}: {e}") continue logger.info(f"[Yicai] Extracted {len(raw_news)} news items") except Exception as e: logger.error(f"[Yicai] Extract failed: {e}") return raw_news def transform_data( self, raw_data: List[Dict[str, Any]], params: NewsQueryParams ) -> List[NewsData]: """转换原始数据为标准 NewsData 格式""" news_list = [] for item in raw_data: try: stock_codes = self._extract_stock_codes( item.get("title", "") + " " + item.get("content", "") ) if params.stock_codes: if not any(code in stock_codes for code in params.stock_codes): continue if params.keywords: text = item.get("title", "") + " " + item.get("content", "") if not any(kw in text for kw in params.keywords): continue news = NewsData( title=item.get("title", ""), content=item.get("content", ""), source=self.SOURCE_NAME, source_url=item.get("url", ""), publish_time=item.get("publish_time", datetime.now()), author=item.get("author"), stock_codes=stock_codes, sentiment=NewsSentiment.NEUTRAL, ) news_list.append(news) except Exception as e: logger.warning(f"[Yicai] Transform failed: {e}") continue if params.limit: news_list = news_list[:params.limit] return news_list def _extract_news_links(self, soup: BeautifulSoup) -> List[Dict[str, str]]: """从页面提取新闻链接""" news_links = [] 
all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) if ('/news/' in href or '/article/' in href) and title: if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://www.yicai.com' + href elif not href.startswith('http'): href = 'https://www.yicai.com/' + href.lstrip('/') if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _fetch_news_detail(self, link_info: Dict[str, str]) -> Optional[Dict[str, Any]]: """获取新闻详情""" url = link_info['url'] title = link_info['title'] try: response = requests.get(url, headers=self.HEADERS, timeout=30) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") content = self._extract_content(soup) if not content: return None publish_time = self._extract_publish_time(soup) author = self._extract_author(soup) return { "title": title, "content": content, "url": url, "publish_time": publish_time, "author": author, } except Exception as e: logger.debug(f"[Yicai] Detail fetch failed: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'm-txt'}, {'class': 'article-content'}, {'class': 'content'}, {'class': 'newsContent'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([ p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True) ]) if content: return self._clean_text(content) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> datetime: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if not time_elem: time_elem = soup.find('time') if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception: pass return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" formats = ['%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M'] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return datetime.now() def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: elem = soup.find('span', {'class': re.compile(r'author|source')}) if elem: return elem.get_text(strip=True) except Exception: pass return None def _extract_stock_codes(self, text: str) -> List[str]: """从文本提取股票代码""" patterns = [ r'(\d{6})\.(SH|SZ|sh|sz)', r'(SH|SZ|sh|sz)(\d{6})', r'[((](\d{6})[))]', ] codes = set() for pattern in patterns: matches = re.findall(pattern, text) for match in matches: if isinstance(match, tuple): code = ''.join(match) else: code = match code = re.sub(r'[^0-9]', '', code) if len(code) == 6: codes.add(code) return list(codes) def _clean_text(self, text: str) -> str: """清理文本""" text = re.sub(r'\s+', ' ', text) return text.strip() ================================================ FILE: backend/app/financial/providers/yicai/provider.py ================================================ """ 第一财经 Provider """ from typing import Dict, Type from ..base import BaseProvider, BaseFetcher, ProviderInfo from .fetchers.news import YicaiNewsFetcher class YicaiProvider(BaseProvider): """ 第一财经数据源 支持的数据类型: - news: 财经新闻 """ @property def info(self) -> ProviderInfo: return ProviderInfo( name="yicai", display_name="第一财经", description="第一财经股市新闻 (yicai.com)", website="https://www.yicai.com/", requires_credentials=False, 
priority=5 # 第五优先级 ) @property def fetchers(self) -> Dict[str, Type[BaseFetcher]]: return { "news": YicaiNewsFetcher, } ================================================ FILE: backend/app/financial/registry.py ================================================ """ Provider 注册中心 支持: 1. 动态注册/注销 Provider 2. 根据数据类型获取 Fetcher 3. 多 Provider 自动降级 来源参考: - OpenBB: Provider Registry 机制 - 设计文档: research/codedeepresearch/OpenBB/FinnewsHunter_improvement_plan.md """ from typing import Dict, Optional, List import logging from .providers.base import BaseProvider, BaseFetcher logger = logging.getLogger(__name__) class ProviderNotFoundError(Exception): """Provider 未找到异常""" pass class FetcherNotFoundError(Exception): """Fetcher 未找到异常""" pass class ProviderRegistry: """ Provider 注册中心 功能: 1. 注册/注销 Provider 2. 根据数据类型获取 Fetcher 3. 支持多 Provider 自动降级 Example: >>> registry = ProviderRegistry() >>> registry.register(SinaProvider()) >>> registry.register(TencentProvider()) >>> >>> # 获取 Fetcher (按优先级自动选择) >>> fetcher = registry.get_fetcher("news") >>> >>> # 指定 Provider >>> fetcher = registry.get_fetcher("news", provider="tencent") """ _instance: Optional["ProviderRegistry"] = None def __new__(cls): """单例模式""" if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._providers: Dict[str, BaseProvider] = {} cls._instance._priority_order: List[str] = [] cls._instance._initialized = False return cls._instance def register(self, provider: BaseProvider) -> None: """ 注册 Provider Args: provider: Provider 实例 Note: - 如果 Provider 已存在,会被替换 - 按 priority 自动排序 """ name = provider.info.name priority = provider.info.priority # 如果已存在,先移除 if name in self._providers: self._priority_order.remove(name) self._providers[name] = provider # 按优先级插入(priority 越小越靠前) inserted = False for i, existing_name in enumerate(self._priority_order): existing_priority = self._providers[existing_name].info.priority if priority < existing_priority: self._priority_order.insert(i, name) inserted = True break if not inserted: self._priority_order.append(name) logger.info( f"Registered provider: {name} " f"(priority={priority}, types={list(provider.fetchers.keys())})" ) def unregister(self, name: str) -> bool: """ 注销 Provider Args: name: Provider 名称 Returns: 是否成功注销 """ if name in self._providers: del self._providers[name] self._priority_order.remove(name) logger.info(f"Unregistered provider: {name}") return True return False def get_provider(self, name: str) -> Optional[BaseProvider]: """ 获取指定 Provider Args: name: Provider 名称 Returns: Provider 实例,如果不存在返回 None """ return self._providers.get(name) def get_fetcher( self, data_type: str, provider: Optional[str] = None ) -> BaseFetcher: """ 获取 Fetcher,支持自动降级 Args: data_type: 数据类型,如 'news', 'stock_price' provider: 可选的 Provider 名称,如果不指定则按优先级选择 Returns: BaseFetcher 实例 Raises: FetcherNotFoundError: 如果没有找到支持该数据类型的 Provider ProviderNotFoundError: 如果指定的 Provider 不存在 Example: >>> # 自动选择最高优先级的 Provider >>> fetcher = registry.get_fetcher("news") >>> >>> # 指定 Provider >>> fetcher = registry.get_fetcher("news", provider="tencent") """ # 如果指定了 Provider if provider: p = self._providers.get(provider) if not p: raise ProviderNotFoundError(f"Provider '{provider}' not found") fetcher = p.get_fetcher(data_type) if not fetcher: raise FetcherNotFoundError( f"Provider '{provider}' does not support data_type='{data_type}'" ) return fetcher # 否则按优先级选择 for p_name in self._priority_order: p = self._providers[p_name] if p.supports(data_type): fetcher = p.get_fetcher(data_type) if fetcher: logger.debug(f"Using {p_name} for 
{data_type}") return fetcher # 没有找到支持的 Provider available = self.get_providers_for_type(data_type) raise FetcherNotFoundError( f"No provider found for data_type='{data_type}'. " f"Available providers for this type: {available}" ) def list_providers(self) -> List[str]: """ 列出所有已注册的 Provider (按优先级排序) Returns: Provider 名称列表 """ return list(self._priority_order) def get_providers_for_type(self, data_type: str) -> List[str]: """ 获取支持指定数据类型的所有 Provider Args: data_type: 数据类型 Returns: 支持该类型的 Provider 名称列表 (按优先级排序) """ return [ name for name in self._priority_order if self._providers[name].supports(data_type) ] def get_all_data_types(self) -> List[str]: """ 获取所有支持的数据类型 Returns: 数据类型列表 """ types = set() for provider in self._providers.values(): types.update(provider.fetchers.keys()) return sorted(types) def clear(self) -> None: """清空所有注册的 Provider""" self._providers.clear() self._priority_order.clear() logger.info("Cleared all providers from registry") def __repr__(self) -> str: return f"" # 全局单例 _registry: Optional[ProviderRegistry] = None def get_registry() -> ProviderRegistry: """ 获取全局 Registry 实例 Returns: ProviderRegistry 单例 """ global _registry if _registry is None: _registry = ProviderRegistry() return _registry def reset_registry() -> ProviderRegistry: """ 重置全局 Registry (主要用于测试) Returns: 新的 ProviderRegistry 实例 """ global _registry if _registry: _registry.clear() _registry = ProviderRegistry() _registry.clear() # 确保单例也被清空 return _registry ================================================ FILE: backend/app/financial/tools.py ================================================ """ 金融数据工具 - 封装为 AgenticX BaseTool 这些工具可以直接被 Agent 调用,内部使用 Provider Registry 获取数据。 设计原则: - 继承 AgenticX BaseTool,保持与框架兼容 - 内部使用 ProviderRegistry 实现多源降级 - 返回标准化的数据格式 来源参考: - 设计文档: research/codedeepresearch/OpenBB/FinnewsHunter_improvement_plan.md """ from typing import List, Optional, Dict, Any import asyncio import logging from agenticx import BaseTool from agenticx.core import ToolMetadata, ToolCategory from .registry import get_registry, FetcherNotFoundError, ProviderNotFoundError from .models.news import NewsQueryParams, NewsData from .models.stock import StockQueryParams, StockPriceData, KlineInterval, AdjustType logger = logging.getLogger(__name__) class FinancialNewsTool(BaseTool): """ 金融新闻获取工具 支持多数据源自动切换,返回标准化的新闻数据。 Example: >>> tool = FinancialNewsTool() >>> result = await tool.aexecute(stock_codes=["600519"], limit=10) >>> print(result["data"]) # List[NewsData.model_dump()] """ def __init__(self): metadata = ToolMetadata( name="financial_news", description="获取金融新闻,支持多数据源自动切换", category=ToolCategory.DATA_ACCESS, version="1.0.0" ) super().__init__(metadata=metadata) def _setup_parameters(self): """设置工具参数(AgenticX BaseTool 要求的抽象方法)""" pass async def aexecute( self, keywords: Optional[List[str]] = None, stock_codes: Optional[List[str]] = None, limit: int = 50, provider: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """ 异步执行新闻获取 Args: keywords: 搜索关键词列表 stock_codes: 关联股票代码列表 limit: 返回条数 provider: 指定数据源 Returns: { "success": bool, "count": int, "provider": str, "data": List[dict] # NewsData.model_dump() } """ # 构建标准查询参数 params = NewsQueryParams( keywords=keywords, stock_codes=stock_codes, limit=limit ) try: # 获取 Fetcher(自动降级) registry = get_registry() fetcher = registry.get_fetcher("news", provider) # 执行 TET Pipeline results: List[NewsData] = await fetcher.fetch(params) # 获取实际使用的 provider 名称 provider_name = fetcher.__class__.__module__.split(".")[-3] return { "success": True, "count": len(results), "provider": 
provider_name, "data": [r.model_dump() for r in results] } except (FetcherNotFoundError, ProviderNotFoundError) as e: logger.error(f"Provider error: {e}") registry = get_registry() return { "success": False, "error": str(e), "available_providers": registry.get_providers_for_type("news") } except Exception as e: logger.exception(f"Unexpected error in FinancialNewsTool: {e}") return { "success": False, "error": f"Unexpected error: {e}" } def execute( self, keywords: Optional[List[str]] = None, stock_codes: Optional[List[str]] = None, limit: int = 50, provider: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """ 同步执行(包装异步方法) """ return asyncio.run(self.aexecute( keywords=keywords, stock_codes=stock_codes, limit=limit, provider=provider, **kwargs )) class StockPriceTool(BaseTool): """ 股票价格获取工具(K线数据) Example: >>> tool = StockPriceTool() >>> result = await tool.aexecute(symbol="600519", interval="1d", limit=30) >>> print(result["data"]) # List[StockPriceData.model_dump()] """ def __init__(self): metadata = ToolMetadata( name="stock_price", description="获取股票K线数据,支持多数据源自动切换", category=ToolCategory.DATA_ACCESS, version="1.0.0" ) super().__init__(metadata=metadata) def _setup_parameters(self): """设置工具参数(AgenticX BaseTool 要求的抽象方法)""" pass async def aexecute( self, symbol: str, interval: str = "1d", limit: int = 90, adjust: str = "qfq", provider: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """ 异步执行价格获取 Args: symbol: 股票代码 interval: K线周期 limit: 返回条数 adjust: 复权类型 provider: 指定数据源 Returns: { "success": bool, "symbol": str, "count": int, "provider": str, "data": List[dict] # StockPriceData.model_dump() } """ try: params = StockQueryParams( symbol=symbol, interval=KlineInterval(interval), limit=limit, adjust=AdjustType(adjust) ) except ValueError as e: return { "success": False, "error": f"Invalid parameter: {e}" } try: registry = get_registry() fetcher = registry.get_fetcher("stock_price", provider) results: List[StockPriceData] = await fetcher.fetch(params) provider_name = fetcher.__class__.__module__.split(".")[-3] return { "success": True, "symbol": symbol, "count": len(results), "provider": provider_name, "data": [r.model_dump() for r in results] } except (FetcherNotFoundError, ProviderNotFoundError) as e: logger.error(f"Provider error: {e}") registry = get_registry() return { "success": False, "error": str(e), "available_providers": registry.get_providers_for_type("stock_price") } except Exception as e: logger.exception(f"Unexpected error in StockPriceTool: {e}") return { "success": False, "error": f"Unexpected error: {e}" } def execute( self, symbol: str, interval: str = "1d", limit: int = 90, adjust: str = "qfq", provider: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """同步执行""" return asyncio.run(self.aexecute( symbol=symbol, interval=interval, limit=limit, adjust=adjust, provider=provider, **kwargs )) # 便捷函数:自动注册默认 Provider def setup_default_providers(): """ 注册默认的 Provider 在应用启动时调用,确保 Registry 中有可用的 Provider。 当前支持的数据源(按优先级排序): 1. sina - 新浪财经 2. tencent - 腾讯财经 3. nbd - 每日经济新闻 4. eastmoney - 东方财富 5. yicai - 第一财经 6. 
163 - 网易财经 """ from .providers.sina import SinaProvider from .providers.tencent import TencentProvider from .providers.nbd import NbdProvider from .providers.eastmoney import EastmoneyProvider from .providers.yicai import YicaiProvider from .providers.netease import NeteaseProvider registry = get_registry() # 定义所有 Provider(按优先级顺序) providers = [ ("sina", SinaProvider), ("tencent", TencentProvider), ("nbd", NbdProvider), ("eastmoney", EastmoneyProvider), ("yicai", YicaiProvider), ("163", NeteaseProvider), ] # 注册所有 Provider for name, provider_class in providers: if name not in registry.list_providers(): try: registry.register(provider_class()) logger.debug(f"Registered provider: {name}") except Exception as e: logger.warning(f"Failed to register provider {name}: {e}") logger.info(f"Registered {len(registry.list_providers())} providers: {registry.list_providers()}") ================================================ FILE: backend/app/knowledge/README.md ================================================ # 知识图谱模块 ## 📊 概述 知识图谱模块为每只股票构建动态的知识图谱,用于智能化的新闻检索和分析。 ## 🎯 核心功能 ### 1. 多维度知识建模 为每家公司建立包含以下信息的知识图谱: - **名称变体**:公司简称、别名、全称 - **业务线**:主营业务、新增业务、已停止业务 - **行业归属**:一级行业、二级行业、细分领域 - **产品服务**:主要产品和服务 - **关联概念**:涉及的热点概念(AI大模型、云计算等) - **检索关键词**:优化检索效果的关键词 ### 2. 智能并发检索 基于知识图谱生成多样化的检索查询,并发调用搜索API: ``` 示例:彩讯股份 (300634) 生成的查询组合: 1. "彩讯股份 300634" 2. "彩讯 股票" 3. "彩讯股份 运营商增值服务" 4. "彩讯 AI大模型应用" 5. "彩讯科技 云计算" 6. ...(最多10条并发查询) ``` ### 3. 动态图谱更新 - **构建时机**:首次定向爬取时自动构建 - **数据来源**: - akshare:基础信息(行业、市值、主营业务) - LLM推理:名称变体、业务细分 - 新闻分析:业务变化、新概念 - 文档解析:深度信息(年报、公告) - **更新机制**:每次定向爬取后自动更新 ## 🏗️ 架构设计 ### 图谱结构 ``` (Company) 公司节点 ├─ HAS_VARIANT ─> (NameVariant) 名称变体 ├─ OPERATES_IN ─> (Business) 业务线 ├─ BELONGS_TO ─> (Industry) 行业 ├─ PROVIDES ─> (Product) 产品 ├─ RELATES_TO ─> (Keyword) 关键词 └─ INVOLVES ─> (Concept) 概念 ``` ### 核心组件 1. **graph_models.py** - 数据模型定义 2. **graph_service.py** - 图谱CRUD服务 3. **knowledge_extractor.py** - 知识提取Agent 4. **parallel_search.py** - 并发检索策略 ## 🚀 使用方法 ### 1. 启动 Neo4j ```bash cd deploy docker-compose -f docker-compose.dev.yml up -d neo4j ``` ### 2. 初始化图谱 ```bash cd backend python init_knowledge_graph.py ``` ### 3. API 调用 #### 查询图谱 ```bash GET /api/v1/knowledge-graph/{stock_code} ``` #### 构建图谱 ```bash POST /api/v1/knowledge-graph/{stock_code}/build { "force_rebuild": false } ``` #### 更新图谱 ```bash POST /api/v1/knowledge-graph/{stock_code}/update { "update_from_news": true, "news_limit": 20 } ``` #### 删除图谱 ```bash DELETE /api/v1/knowledge-graph/{stock_code} ``` ### 4. 自动集成 定向爬取时自动使用知识图谱: 1. **检查图谱**:如果不存在,自动从 akshare + LLM 构建 2. **并发检索**:基于图谱生成的多个关键词并发搜索 3. **更新图谱**:爬取完成后,从新闻中提取新信息更新图谱 ## 📈 效果对比 ### 传统单关键词检索 ```python query = "彩讯股份 股票 300634" results = search(query) # ~20-30条 ``` ### 基于知识图谱的并发检索 ```python queries = [ "彩讯股份 300634", "彩讯 运营商增值服务", "彩讯股份 AI大模型应用", "彩讯科技 云计算", ... ] results = parallel_search(queries) # ~100-200条,去重后70-130条 ``` **召回率提升:3-5倍** ## 🔧 配置 环境变量: ```bash NEO4J_URI=bolt://localhost:7687 NEO4J_USER=neo4j NEO4J_PASSWORD=finnews_neo4j_password ``` ## 📊 监控 访问 Neo4j 浏览器: - URL: http://localhost:7474 - 用户名: neo4j - 密码: finnews_neo4j_password 示例查询: ```cypher // 查看所有公司 MATCH (c:Company) RETURN c // 查看公司的完整图谱 MATCH (c:Company {stock_code: 'SZ300634'})-[r]->(n) RETURN c, r, n // 查看业务线 MATCH (c:Company)-[:OPERATES_IN]->(b:Business) WHERE b.status = 'active' RETURN c.stock_name, b.business_name, b.status ``` ## ⚠️ 注意事项 1. **LLM成本**:图谱构建和更新会调用LLM,注意API成本 2. **并发限制**:并发检索默认5个worker,可根据API限制调整 3. **图谱更新**:建议每次定向爬取后自动更新,保持图谱时效性 4. 
**数据质量**:LLM提取的信息需要人工review,建议提供review接口 ================================================ FILE: backend/app/knowledge/__init__.py ================================================ """ 知识图谱模块 """ from .graph_models import ( CompanyNode, NameVariantNode, BusinessNode, IndustryNode, ProductNode, KeywordNode, ConceptNode, CompanyKnowledgeGraph, SearchKeywordSet, NodeType, RelationType ) from .graph_service import KnowledgeGraphService __all__ = [ "CompanyNode", "NameVariantNode", "BusinessNode", "IndustryNode", "ProductNode", "KeywordNode", "ConceptNode", "CompanyKnowledgeGraph", "SearchKeywordSet", "NodeType", "RelationType", "KnowledgeGraphService" ] ================================================ FILE: backend/app/knowledge/graph_models.py ================================================ """ 知识图谱数据模型 定义公司知识图谱的节点和关系结构 """ from typing import List, Dict, Any, Optional from pydantic import BaseModel, Field from datetime import datetime from enum import Enum class NodeType(str, Enum): """节点类型枚举""" COMPANY = "Company" # 公司 NAME_VARIANT = "NameVariant" # 名称变体 BUSINESS = "Business" # 业务线 INDUSTRY = "Industry" # 行业 PRODUCT = "Product" # 产品/服务 KEYWORD = "Keyword" # 检索关键词 CONCEPT = "Concept" # 概念/主题 PARTNER = "Partner" # 合作伙伴 class RelationType(str, Enum): """关系类型枚举""" HAS_VARIANT = "HAS_VARIANT" # 有变体 OPERATES_IN = "OPERATES_IN" # 运营于(业务线) BELONGS_TO = "BELONGS_TO" # 属于(行业) PROVIDES = "PROVIDES" # 提供(产品) RELATES_TO = "RELATES_TO" # 关联(关键词) INVOLVES = "INVOLVES" # 涉及(概念) COOPERATES_WITH = "COOPERATES_WITH" # 合作(伙伴) UPSTREAM = "UPSTREAM" # 上游 DOWNSTREAM = "DOWNSTREAM" # 下游 class CompanyNode(BaseModel): """公司节点""" stock_code: str = Field(description="股票代码(如 SZ300634)") stock_name: str = Field(description="股票全称(如 彩讯股份)") short_code: str = Field(description="纯数字代码(如 300634)") industry: Optional[str] = Field(default=None, description="所属行业") sector: Optional[str] = Field(default=None, description="所属板块") market_cap: Optional[float] = Field(default=None, description="市值") listed_date: Optional[str] = Field(default=None, description="上市日期") created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) class NameVariantNode(BaseModel): """名称变体节点""" variant: str = Field(description="变体名称(如 彩讯、彩讯科技)") variant_type: str = Field(description="变体类型: abbreviation, alias, full_name") created_at: datetime = Field(default_factory=datetime.utcnow) class BusinessNode(BaseModel): """业务线节点""" business_name: str = Field(description="业务名称") business_type: str = Field(description="业务类型: main, new, stopped") description: Optional[str] = Field(default=None, description="业务描述") start_date: Optional[str] = Field(default=None, description="开始日期") end_date: Optional[str] = Field(default=None, description="结束日期(如果停止)") status: str = Field(default="active", description="状态: active, stopped, planned") created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) class IndustryNode(BaseModel): """行业节点""" industry_name: str = Field(description="行业名称") industry_code: Optional[str] = Field(default=None, description="行业代码") level: int = Field(default=1, description="层级: 1=一级行业, 2=二级行业") created_at: datetime = Field(default_factory=datetime.utcnow) class ProductNode(BaseModel): """产品/服务节点""" product_name: str = Field(description="产品名称") product_type: str = Field(description="产品类型: software, hardware, service") description: Optional[str] = Field(default=None, description="产品描述") created_at: datetime = 
Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) class KeywordNode(BaseModel): """检索关键词节点""" keyword: str = Field(description="关键词") keyword_type: str = Field(description="类型: business, product, industry, general") weight: float = Field(default=1.0, description="权重(检索时的重要性)") created_at: datetime = Field(default_factory=datetime.utcnow) class ConceptNode(BaseModel): """概念/主题节点""" concept_name: str = Field(description="概念名称(如 AI大模型、元宇宙)") description: Optional[str] = Field(default=None, description="概念描述") hot_level: int = Field(default=0, description="热度等级 0-10") created_at: datetime = Field(default_factory=datetime.utcnow) class CompanyKnowledgeGraph(BaseModel): """公司知识图谱完整结构(用于导入导出)""" company: CompanyNode name_variants: List[NameVariantNode] = Field(default_factory=list) businesses: List[BusinessNode] = Field(default_factory=list) industries: List[IndustryNode] = Field(default_factory=list) products: List[ProductNode] = Field(default_factory=list) keywords: List[KeywordNode] = Field(default_factory=list) concepts: List[ConceptNode] = Field(default_factory=list) class SearchKeywordSet(BaseModel): """检索关键词集合(用于定向爬取)""" stock_code: str stock_name: str # 名称相关 name_keywords: List[str] = Field(default_factory=list, description="名称变体") # 业务相关 business_keywords: List[str] = Field(default_factory=list, description="业务线关键词") # 行业相关 industry_keywords: List[str] = Field(default_factory=list, description="行业关键词") # 产品相关 product_keywords: List[str] = Field(default_factory=list, description="产品关键词") # 概念相关 concept_keywords: List[str] = Field(default_factory=list, description="概念关键词") # 组合查询 combined_queries: List[str] = Field(default_factory=list, description="预组合的查询串") def get_all_keywords(self) -> List[str]: """获取所有关键词(去重)""" all_kw = ( self.name_keywords + self.business_keywords + self.industry_keywords + self.product_keywords + self.concept_keywords ) return list(set(all_kw)) def generate_search_queries(self, max_queries: int = 10) -> List[str]: """ 生成多样化的搜索查询组合 Args: max_queries: 最大查询数量 Returns: 查询字符串列表 """ queries = [] # 1. 核心查询:股票名称 + 股票代码 if self.name_keywords: queries.append(f"{self.stock_name} {self.stock_code}") queries.append(f"{self.name_keywords[0]} 股票") # 2. 业务线查询 for business in self.business_keywords[:3]: # 最多3个业务线 queries.append(f"{self.stock_name} {business}") if len(self.name_keywords) > 1: queries.append(f"{self.name_keywords[0]} {business}") # 3. 概念查询 for concept in self.concept_keywords[:2]: # 最多2个概念 queries.append(f"{self.stock_name} {concept}") # 4. 产品查询 for product in self.product_keywords[:2]: # 最多2个产品 queries.append(f"{self.stock_name} {product}") # 5. 
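# For the knowledge README's 彩讯股份 example, steps 1-4 yield queries such as
# "彩讯股份 300634", "彩讯 股票", "彩讯股份 运营商增值服务" before the
# pre-combined queries below are appended.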
使用预组合查询 queries.extend(self.combined_queries) # 去重并限制数量 unique_queries = list(dict.fromkeys(queries)) # 保持顺序的去重 return unique_queries[:max_queries] ================================================ FILE: backend/app/knowledge/graph_service.py ================================================ """ 知识图谱服务 提供公司知识图谱的创建、查询、更新操作 """ import logging from typing import List, Dict, Any, Optional from datetime import datetime from ..core.neo4j_client import get_neo4j_client from .graph_models import ( CompanyNode, NameVariantNode, BusinessNode, IndustryNode, ProductNode, KeywordNode, ConceptNode, CompanyKnowledgeGraph, SearchKeywordSet, NodeType, RelationType ) logger = logging.getLogger(__name__) class KnowledgeGraphService: """知识图谱服务""" def __init__(self): self.neo4j = get_neo4j_client() self._ensure_constraints() def _ensure_constraints(self): """确保数据库约束和索引存在""" constraints = [ # 公司节点唯一约束 "CREATE CONSTRAINT company_code IF NOT EXISTS FOR (c:Company) REQUIRE c.stock_code IS UNIQUE", # 索引加速查询 "CREATE INDEX company_name IF NOT EXISTS FOR (c:Company) ON (c.stock_name)", "CREATE INDEX business_name IF NOT EXISTS FOR (b:Business) ON (b.business_name)", "CREATE INDEX keyword_text IF NOT EXISTS FOR (k:Keyword) ON (k.keyword)", ] for constraint in constraints: try: self.neo4j.execute_write(constraint) except Exception as e: # 约束可能已存在,忽略错误 logger.debug(f"Constraint creation skipped: {e}") # ============ 公司节点操作 ============ def create_or_update_company(self, company: CompanyNode) -> bool: """ 创建或更新公司节点 Args: company: 公司节点数据 Returns: 是否成功 """ query = """ MERGE (c:Company {stock_code: $stock_code}) SET c.stock_name = $stock_name, c.short_code = $short_code, c.industry = $industry, c.sector = $sector, c.market_cap = $market_cap, c.listed_date = $listed_date, c.updated_at = datetime(), c.created_at = coalesce(c.created_at, datetime()) RETURN c """ params = company.model_dump() params['created_at'] = company.created_at.isoformat() params['updated_at'] = datetime.utcnow().isoformat() try: self.neo4j.execute_write(query, params) logger.info(f"✅ 公司节点已创建/更新: {company.stock_name}({company.stock_code})") return True except Exception as e: logger.error(f"❌ 公司节点创建失败: {e}") return False def get_company(self, stock_code: str) -> Optional[Dict[str, Any]]: """获取公司节点""" query = """ MATCH (c:Company {stock_code: $stock_code}) RETURN c """ results = self.neo4j.execute_query(query, {"stock_code": stock_code}) return results[0]['c'] if results else None # ============ 名称变体操作 ============ def add_name_variants( self, stock_code: str, variants: List[NameVariantNode] ) -> bool: """ 添加名称变体 Args: stock_code: 股票代码 variants: 名称变体列表 Returns: 是否成功 """ for variant in variants: query = """ MATCH (c:Company {stock_code: $stock_code}) MERGE (v:NameVariant {variant: $variant}) SET v.variant_type = $variant_type, v.created_at = $created_at MERGE (c)-[r:HAS_VARIANT]->(v) RETURN v """ params = { "stock_code": stock_code, "variant": variant.variant, "variant_type": variant.variant_type, "created_at": variant.created_at.isoformat() } try: self.neo4j.execute_write(query, params) except Exception as e: logger.error(f"添加名称变体失败 {variant.variant}: {e}") return False logger.info(f"✅ 已添加 {len(variants)} 个名称变体") return True # ============ 业务线操作 ============ def add_business( self, stock_code: str, business: BusinessNode ) -> bool: """添加业务线""" query = """ MATCH (c:Company {stock_code: $stock_code}) MERGE (b:Business {business_name: $business_name}) SET b.business_type = $business_type, b.description = $description, b.start_date = $start_date, b.end_date = 
$end_date, b.status = $status, b.updated_at = datetime(), b.created_at = coalesce(b.created_at, datetime()) MERGE (c)-[r:OPERATES_IN]->(b) RETURN b """ params = business.model_dump() params['stock_code'] = stock_code try: self.neo4j.execute_write(query, params) logger.info(f"✅ 业务线已添加: {business.business_name}") return True except Exception as e: logger.error(f"❌ 业务线添加失败: {e}") return False def stop_business( self, stock_code: str, business_name: str, end_date: str = None ) -> bool: """停止业务线""" query = """ MATCH (c:Company {stock_code: $stock_code})-[:OPERATES_IN]->(b:Business {business_name: $business_name}) SET b.status = 'stopped', b.end_date = $end_date, b.updated_at = datetime() RETURN b """ params = { "stock_code": stock_code, "business_name": business_name, "end_date": end_date or datetime.utcnow().strftime("%Y-%m-%d") } try: self.neo4j.execute_write(query, params) logger.info(f"✅ 业务线已停止: {business_name}") return True except Exception as e: logger.error(f"❌ 业务线停止失败: {e}") return False # ============ 关键词操作 ============ def add_keywords( self, stock_code: str, keywords: List[KeywordNode], relation_type: str = "RELATES_TO" ) -> bool: """添加检索关键词""" for keyword in keywords: query = """ MATCH (c:Company {stock_code: $stock_code}) MERGE (k:Keyword {keyword: $keyword}) SET k.keyword_type = $keyword_type, k.weight = $weight, k.created_at = $created_at MERGE (c)-[r:RELATES_TO]->(k) RETURN k """ params = { "stock_code": stock_code, "keyword": keyword.keyword, "keyword_type": keyword.keyword_type, "weight": keyword.weight, "created_at": keyword.created_at.isoformat() } try: self.neo4j.execute_write(query, params) except Exception as e: logger.error(f"添加关键词失败 {keyword.keyword}: {e}") return False logger.info(f"✅ 已添加 {len(keywords)} 个关键词") return True # ============ 概念操作 ============ def add_concepts( self, stock_code: str, concepts: List[ConceptNode] ) -> bool: """添加概念/主题""" for concept in concepts: query = """ MATCH (c:Company {stock_code: $stock_code}) MERGE (con:Concept {concept_name: $concept_name}) SET con.description = $description, con.hot_level = $hot_level, con.created_at = $created_at MERGE (c)-[r:INVOLVES]->(con) RETURN con """ params = { "stock_code": stock_code, "concept_name": concept.concept_name, "description": concept.description, "hot_level": concept.hot_level, "created_at": concept.created_at.isoformat() } try: self.neo4j.execute_write(query, params) except Exception as e: logger.error(f"添加概念失败 {concept.concept_name}: {e}") return False logger.info(f"✅ 已添加 {len(concepts)} 个概念") return True # ============ 完整图谱操作 ============ def build_company_graph(self, graph: CompanyKnowledgeGraph) -> bool: """ 构建完整的公司知识图谱 Args: graph: 公司知识图谱数据 Returns: 是否成功 """ try: # 1. 创建公司节点 self.create_or_update_company(graph.company) # 2. 添加名称变体 if graph.name_variants: self.add_name_variants(graph.company.stock_code, graph.name_variants) # 3. 添加业务线 for business in graph.businesses: self.add_business(graph.company.stock_code, business) # 4. 添加行业 for industry in graph.industries: self._add_industry(graph.company.stock_code, industry) # 5. 添加产品 for product in graph.products: self._add_product(graph.company.stock_code, product) # 6. 添加关键词 if graph.keywords: self.add_keywords(graph.company.stock_code, graph.keywords) # 7. 
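# A minimal construction sketch (illustrative, not executed here): only the
# CompanyNode fields are mandatory; the other node lists default to empty.
#
#   graph = CompanyKnowledgeGraph(
#       company=CompanyNode(stock_code="SZ300634", stock_name="彩讯股份",
#                           short_code="300634"),
#       name_variants=[NameVariantNode(variant="彩讯", variant_type="abbreviation")],
#   )
#   KnowledgeGraphService().build_company_graph(graph)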
添加概念 if graph.concepts: self.add_concepts(graph.company.stock_code, graph.concepts) logger.info(f"✅ 知识图谱构建完成: {graph.company.stock_name}") return True except Exception as e: logger.error(f"❌ 知识图谱构建失败: {e}") return False def _add_industry(self, stock_code: str, industry: IndustryNode) -> bool: """添加行业节点(内部方法)""" query = """ MATCH (c:Company {stock_code: $stock_code}) MERGE (i:Industry {industry_name: $industry_name}) SET i.industry_code = $industry_code, i.level = $level, i.created_at = $created_at MERGE (c)-[r:BELONGS_TO]->(i) RETURN i """ params = industry.model_dump() params['stock_code'] = stock_code try: self.neo4j.execute_write(query, params) return True except Exception as e: logger.error(f"行业添加失败: {e}") return False def _add_product(self, stock_code: str, product: ProductNode) -> bool: """添加产品节点(内部方法)""" query = """ MATCH (c:Company {stock_code: $stock_code}) MERGE (p:Product {product_name: $product_name}) SET p.product_type = $product_type, p.description = $description, p.updated_at = datetime(), p.created_at = coalesce(p.created_at, datetime()) MERGE (c)-[r:PROVIDES]->(p) RETURN p """ params = product.model_dump() params['stock_code'] = stock_code try: self.neo4j.execute_write(query, params) return True except Exception as e: logger.error(f"产品添加失败: {e}") return False # ============ 查询操作 ============ def get_company_graph(self, stock_code: str) -> Optional[CompanyKnowledgeGraph]: """ 获取完整的公司知识图谱 Args: stock_code: 股票代码 Returns: 公司知识图谱或None """ # 查询公司及其所有关联节点 query = """ MATCH (c:Company {stock_code: $stock_code}) OPTIONAL MATCH (c)-[:HAS_VARIANT]->(v:NameVariant) OPTIONAL MATCH (c)-[:OPERATES_IN]->(b:Business) OPTIONAL MATCH (c)-[:BELONGS_TO]->(i:Industry) OPTIONAL MATCH (c)-[:PROVIDES]->(p:Product) OPTIONAL MATCH (c)-[:RELATES_TO]->(k:Keyword) OPTIONAL MATCH (c)-[:INVOLVES]->(con:Concept) RETURN c, collect(DISTINCT v) as variants, collect(DISTINCT b) as businesses, collect(DISTINCT i) as industries, collect(DISTINCT p) as products, collect(DISTINCT k) as keywords, collect(DISTINCT con) as concepts """ try: results = self.neo4j.execute_query(query, {"stock_code": stock_code}) if not results or not results[0]['c']: return None data = results[0] company_data = dict(data['c']) # 构建完整图谱 graph = CompanyKnowledgeGraph( company=CompanyNode(**company_data), name_variants=[NameVariantNode(**dict(v)) for v in data['variants'] if v], businesses=[BusinessNode(**dict(b)) for b in data['businesses'] if b], industries=[IndustryNode(**dict(i)) for i in data['industries'] if i], products=[ProductNode(**dict(p)) for p in data['products'] if p], keywords=[KeywordNode(**dict(k)) for k in data['keywords'] if k], concepts=[ConceptNode(**dict(c)) for c in data['concepts'] if c] ) return graph except Exception as e: logger.error(f"查询公司图谱失败: {e}") return None def get_search_keywords(self, stock_code: str) -> Optional[SearchKeywordSet]: """ 获取用于检索的关键词集合 Args: stock_code: 股票代码 Returns: 检索关键词集合 """ graph = self.get_company_graph(stock_code) if not graph: return None # 构建检索关键词集合 keyword_set = SearchKeywordSet( stock_code=stock_code, stock_name=graph.company.stock_name, name_keywords=[v.variant for v in graph.name_variants], business_keywords=[b.business_name for b in graph.businesses if b.status == "active"], industry_keywords=[i.industry_name for i in graph.industries], product_keywords=[p.product_name for p in graph.products], concept_keywords=[c.concept_name for c in graph.concepts] ) # 生成组合查询 keyword_set.combined_queries = keyword_set.generate_search_queries(max_queries=10) return keyword_set # ============ 
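# Typical retrieval flow built on the keyword set returned above (a sketch;
# it assumes a graph for the stock already exists in Neo4j):
#
#   from app.knowledge.parallel_search import ParallelSearchStrategy
#
#   keyword_set = get_graph_service().get_search_keywords("SZ300634")
#   if keyword_set:
#       news = ParallelSearchStrategy(max_workers=5).search_with_multiple_keywords(
#           keyword_set, days=30, max_results_per_query=50
#       )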
图谱更新 ============ def update_from_news( self, stock_code: str, news_content: str, extracted_info: Dict[str, Any] ) -> bool: """ 根据新闻更新图谱 Args: stock_code: 股票代码 news_content: 新闻内容 extracted_info: 提取的信息(由 LLM 提取) { "new_businesses": [...], "stopped_businesses": [...], "new_products": [...], "new_concepts": [...] } Returns: 是否成功 """ try: # 添加新业务线 for biz_name in extracted_info.get("new_businesses", []): business = BusinessNode( business_name=biz_name, business_type="new", status="active", start_date=datetime.utcnow().strftime("%Y-%m-%d") ) self.add_business(stock_code, business) # 停止业务线 for biz_name in extracted_info.get("stopped_businesses", []): self.stop_business(stock_code, biz_name) # 添加新产品 for prod_name in extracted_info.get("new_products", []): product = ProductNode( product_name=prod_name, product_type="service" ) self._add_product(stock_code, product) # 添加新概念 for concept_name in extracted_info.get("new_concepts", []): concept = ConceptNode( concept_name=concept_name, hot_level=5 ) self.add_concepts(stock_code, [concept]) logger.info(f"✅ 图谱已更新(基于新闻)") return True except Exception as e: logger.error(f"❌ 图谱更新失败: {e}") return False # ============ 统计和管理 ============ def get_graph_stats(self, stock_code: str) -> Dict[str, int]: """获取图谱统计信息""" query = """ MATCH (c:Company {stock_code: $stock_code}) OPTIONAL MATCH (c)-[:HAS_VARIANT]->(v:NameVariant) OPTIONAL MATCH (c)-[:OPERATES_IN]->(b:Business) OPTIONAL MATCH (c)-[:BELONGS_TO]->(i:Industry) OPTIONAL MATCH (c)-[:PROVIDES]->(p:Product) OPTIONAL MATCH (c)-[:RELATES_TO]->(k:Keyword) OPTIONAL MATCH (c)-[:INVOLVES]->(con:Concept) RETURN count(DISTINCT v) as variants_count, count(DISTINCT b) as businesses_count, count(DISTINCT i) as industries_count, count(DISTINCT p) as products_count, count(DISTINCT k) as keywords_count, count(DISTINCT con) as concepts_count """ try: results = self.neo4j.execute_query(query, {"stock_code": stock_code}) if results: return dict(results[0]) return {} except Exception as e: logger.error(f"查询图谱统计失败: {e}") return {} def delete_company_graph(self, stock_code: str) -> bool: """删除公司及其所有关联节点""" query = """ MATCH (c:Company {stock_code: $stock_code}) OPTIONAL MATCH (c)-[r]->(n) DETACH DELETE c, n """ try: self.neo4j.execute_write(query, {"stock_code": stock_code}) logger.info(f"✅ 公司图谱已删除: {stock_code}") return True except Exception as e: logger.error(f"❌ 图谱删除失败: {e}") return False def list_all_companies(self) -> List[Dict[str, str]]: """列出所有公司""" query = """ MATCH (c:Company) RETURN c.stock_code as stock_code, c.stock_name as stock_name, c.industry as industry ORDER BY c.stock_code """ try: return self.neo4j.execute_query(query) except Exception as e: logger.error(f"查询公司列表失败: {e}") return [] # 便捷函数 def get_graph_service() -> KnowledgeGraphService: """获取知识图谱服务实例""" return KnowledgeGraphService() ================================================ FILE: backend/app/knowledge/knowledge_extractor.py ================================================ """ 知识提取器 从多种数据源提取公司知识并构建图谱 """ import logging import json from typing import List, Dict, Any, Optional from datetime import datetime from agenticx import Agent from ..services.llm_service import get_llm_provider from .graph_models import ( CompanyNode, NameVariantNode, BusinessNode, IndustryNode, ProductNode, KeywordNode, ConceptNode, CompanyKnowledgeGraph ) logger = logging.getLogger(__name__) class KnowledgeExtractorAgent(Agent): """ 知识提取智能体 从多种数据源提取公司信息并构建知识图谱 """ def __init__(self, llm_provider=None, organization_id: str = "finnews"): super().__init__( name="KnowledgeExtractor", 
role="知识提取专家", goal="从多种数据源提取公司信息,构建全面的知识图谱", backstory="""你是一位专业的企业分析师和知识工程师。 你擅长从各类数据源(财务数据、新闻、公告、研报)中提取关键信息, 识别公司的业务线、产品、行业归属、关联概念等, 并将这些信息结构化为知识图谱,用于后续的智能检索和分析。""", organization_id=organization_id ) if llm_provider is None: llm_provider = get_llm_provider() object.__setattr__(self, '_llm_provider', llm_provider) logger.info(f"Initialized {self.name} agent") async def extract_from_akshare( self, stock_code: str, stock_name: str, stock_info: Dict[str, Any] ) -> CompanyKnowledgeGraph: """ 从 akshare 数据提取基础信息 Args: stock_code: 股票代码 stock_name: 股票名称 stock_info: akshare 返回的股票信息 Returns: 公司知识图谱 """ # 获取当前时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") # 提取纯数字代码 short_code = stock_code if stock_code.startswith("SH") or stock_code.startswith("SZ"): short_code = stock_code[2:] # 创建公司节点 company = CompanyNode( stock_code=stock_code, stock_name=stock_name, short_code=short_code, industry=stock_info.get("industry"), sector=stock_info.get("sector"), market_cap=stock_info.get("market_cap"), listed_date=stock_info.get("listed_date") ) # 生成名称变体(通过 LLM 推理) name_variants_prompt = f"""请为以下公司生成可能的名称变体(简称、别名等): 【当前时间】 {current_time} 【公司信息】 股票代码: {stock_code} 公司全称: {stock_name} 所属行业: {stock_info.get('industry', '未知')} 请以JSON格式返回名称变体列表,每个变体包含: - variant: 变体名称 - variant_type: 类型(abbreviation=简称, alias=别名, full_name=全称) 示例: ```json [ {{"variant": "彩讯", "variant_type": "abbreviation"}}, {{"variant": "彩讯科技", "variant_type": "alias"}}, {{"variant": "{stock_name}", "variant_type": "full_name"}} ] ``` 只返回JSON,不要其他解释。""" try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": name_variants_prompt} ]) content = response.content if hasattr(response, 'content') else str(response) # 提取JSON import re json_match = re.search(r'\[.*\]', content, re.DOTALL) if json_match: variants_data = json.loads(json_match.group()) name_variants = [NameVariantNode(**v) for v in variants_data] else: # 默认变体 name_variants = [ NameVariantNode(variant=stock_name, variant_type="full_name"), NameVariantNode(variant=stock_name[:2], variant_type="abbreviation") ] logger.warning("LLM 未返回有效JSON,使用默认变体") except Exception as e: logger.error(f"名称变体提取失败: {e}") name_variants = [ NameVariantNode(variant=stock_name, variant_type="full_name") ] # 生成业务线(通过 LLM 推理 + akshare 数据) business_prompt = f"""请分析以下公司的主营业务线: 【当前时间】 {current_time} 【公司信息】 股票代码: {stock_code} 公司名称: {stock_name} 所属行业: {stock_info.get('industry', '未知')} 主营业务: {stock_info.get('main_business', '未知')} 请以JSON格式返回业务线列表,每个业务包含: - business_name: 业务名称(简洁) - business_type: 类型(main=主营, new=新增, stopped=已停止) - description: 业务描述 - status: 状态(active=活跃, stopped=已停止) 示例: ```json [ {{"business_name": "运营商增值服务", "business_type": "main", "description": "为运营商提供增值业务", "status": "active"}}, {{"business_name": "AI大模型应用", "business_type": "new", "description": "AI应用开发与落地", "status": "active"}} ] ``` 只返回JSON数组,不要其他解释。""" try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": business_prompt} ]) content = response.content if hasattr(response, 'content') else str(response) # 提取JSON json_match = re.search(r'\[.*\]', content, re.DOTALL) if json_match: businesses_data = json.loads(json_match.group()) businesses = [BusinessNode(**b) for b in businesses_data] else: businesses = [] logger.warning("LLM 未返回有效业务线JSON") except Exception as e: logger.error(f"业务线提取失败: {e}") businesses = [] # 行业节点 industries = [] if stock_info.get('industry'): 
industries.append(IndustryNode( industry_name=stock_info['industry'], level=1 )) # 返回基础图谱 return CompanyKnowledgeGraph( company=company, name_variants=name_variants, businesses=businesses, industries=industries, products=[], keywords=[], concepts=[] ) async def extract_from_news( self, stock_code: str, stock_name: str, news_list: List[Dict[str, Any]] ) -> Dict[str, Any]: """ 从新闻中提取业务变化和概念 Args: stock_code: 股票代码 stock_name: 股票名称 news_list: 新闻列表 Returns: 提取的信息 """ if not news_list: return { "new_businesses": [], "stopped_businesses": [], "new_products": [], "new_concepts": [] } # 获取当前时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") # 汇总新闻 news_summary = "\n\n".join([ f"【{i+1}】{news.get('title', '')}\n{news.get('content', '')[:300]}..." for i, news in enumerate(news_list[:10]) ]) prompt = f"""请分析以下新闻,提取{stock_name}公司的业务变化和相关概念: 【当前时间】 {current_time} 【公司】{stock_name}({stock_code}) 【近期新闻】 {news_summary} 请从新闻中提取: 1. **新增业务线**:公司新开拓的业务方向 2. **停止业务线**:公司明确表示停止或退出的业务 3. **新产品/服务**:公司推出的新产品或服务 4. **关联概念**:新闻中提到的热门概念(如 AI大模型、云计算、元宇宙等) 以JSON格式返回: ```json {{ "new_businesses": ["业务1", "业务2"], "stopped_businesses": ["业务3"], "new_products": ["产品1", "产品2"], "new_concepts": ["概念1", "概念2"] }} ``` 注意: - 只提取明确的信息,不要臆测 - 如果没有相关信息,返回空数组 - 只返回JSON,不要其他文字 JSON:""" try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": prompt} ]) content = response.content if hasattr(response, 'content') else str(response) # 提取JSON import re json_match = re.search(r'\{.*\}', content, re.DOTALL) if json_match: extracted = json.loads(json_match.group()) logger.info(f"✅ 从新闻提取信息: {extracted}") return extracted else: logger.warning("LLM 未返回有效JSON") return { "new_businesses": [], "stopped_businesses": [], "new_products": [], "new_concepts": [] } except Exception as e: logger.error(f"新闻信息提取失败: {e}") return { "new_businesses": [], "stopped_businesses": [], "new_products": [], "new_concepts": [] } async def extract_from_document( self, stock_code: str, stock_name: str, document_content: str, document_type: str = "annual_report" ) -> Dict[str, Any]: """ 从PDF/Word文档提取深度信息 Args: stock_code: 股票代码 stock_name: 股票名称 document_content: 文档内容(已通过MinerU解析) document_type: 文档类型(annual_report=年报, announcement=公告) Returns: 提取的信息 """ # 获取当前时间 current_time = datetime.now().strftime("%Y年%m月%d日 %H:%M") prompt = f"""请从以下{stock_name}的{document_type}中提取详细的业务信息: 【当前时间】 {current_time} 【公司】{stock_name}({stock_code}) 【文档内容】(前3000字) {document_content[:3000]} 请提取: 1. **主营业务**:公司当前的核心业务(详细) 2. **新增业务**:文档中提到的新业务拓展 3. **主要产品**:公司的主要产品或服务 4. **行业定位**:所属行业和细分领域 5. 
**战略方向**:未来战略和关注的热点领域 以JSON格式返回: ```json {{ "main_businesses": [ {{"name": "业务1", "description": "详细描述"}} ], "new_businesses": [ {{"name": "业务2", "description": "详细描述"}} ], "products": [ {{"name": "产品1", "type": "software/hardware/service", "description": "描述"}} ], "industries": ["一级行业", "二级行业"], "concepts": ["概念1", "概念2"], "keywords": ["关键词1", "关键词2"] }} ``` 只返回JSON,不要其他解释。""" try: response = self._llm_provider.invoke([ {"role": "system", "content": f"你是{self.role},{self.backstory}"}, {"role": "user", "content": prompt} ]) content = response.content if hasattr(response, 'content') else str(response) # 提取JSON import re json_match = re.search(r'\{.*\}', content, re.DOTALL) if json_match: extracted = json.loads(json_match.group()) logger.info(f"✅ 从文档提取信息: {len(extracted.get('products', []))}个产品, {len(extracted.get('concepts', []))}个概念") return extracted else: logger.warning("LLM 未返回有效JSON") return {} except Exception as e: logger.error(f"文档信息提取失败: {e}") return {} class AkshareKnowledgeExtractor: """ 从 akshare 提取基础信息,构建简单图谱并生成搜索关键词 """ @staticmethod def extract_company_info(stock_code: str) -> Optional[Dict[str, Any]]: """ 从 akshare 获取公司基础信息 Args: stock_code: 股票代码 Returns: 公司信息字典 """ try: import akshare as ak # 提取纯数字代码 pure_code = stock_code if stock_code.startswith("SH") or stock_code.startswith("SZ"): pure_code = stock_code[2:] logger.info(f"🔍 从 akshare 获取公司信息: {pure_code}") # 获取个股信息 try: # 尝试获取实时行情(包含基本信息) stock_df = ak.stock_individual_info_em(symbol=pure_code) if stock_df is not None and not stock_df.empty: # 打印 DataFrame 结构用于调试 logger.info(f"📋 akshare 返回 DataFrame: columns={list(stock_df.columns)}, rows={len(stock_df)}") # 转换为字典 - 兼容不同的列名格式 info_dict = {} # 确定列名 columns = list(stock_df.columns) key_col = None value_col = None # 尝试找到 key 列 for col in ['item', '属性', 'name', '项目']: if col in columns: key_col = col break # 尝试找到 value 列 for col in ['value', '值', 'data', '数值']: if col in columns: value_col = col break # 如果只有两列,直接使用 if len(columns) == 2 and (key_col is None or value_col is None): key_col, value_col = columns[0], columns[1] if key_col and value_col: for _, row in stock_df.iterrows(): try: key = str(row[key_col]) if row[key_col] is not None else '' value = str(row[value_col]) if row[value_col] is not None else '' if key and value and key != 'nan' and value != 'nan': info_dict[key] = value except Exception as row_err: logger.debug(f"跳过行: {row_err}") continue else: logger.warning(f"⚠️ 无法识别 DataFrame 列结构: {columns}") logger.info(f"📊 解析到 {len(info_dict)} 个字段: {list(info_dict.keys())[:10]}...") # 提取关键字段 result = { "industry": info_dict.get("行业") or info_dict.get("所属行业"), "sector": info_dict.get("板块") or info_dict.get("所属板块"), "main_business": info_dict.get("主营业务") or info_dict.get("经营范围"), "total_market_cap": info_dict.get("总市值"), "listed_date": info_dict.get("上市时间"), "raw_data": info_dict } main_business_preview = (result.get('main_business') or '')[:30] logger.info(f"✅ 获取到公司信息: 行业={result.get('industry')}, 主营={main_business_preview}...") return result else: logger.warning(f"⚠️ akshare 未返回数据: {pure_code}") return None except Exception as e: logger.error(f"❌ akshare 查询失败: {e}", exc_info=True) return None except ImportError: logger.error("akshare 未安装") return None except Exception as e: logger.error(f"提取公司信息失败: {e}") return None @staticmethod def generate_search_keywords( stock_code: str, stock_name: str, akshare_info: Optional[Dict[str, Any]] = None ) -> Dict[str, List[str]]: """ 基于股票信息生成分层关键词 返回两类关键词: - core_keywords: 核心关键词(公司名、代码等,必须包含) - extension_keywords: 扩展关键词(行业、业务、人名等,用于组合) 
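Illustrative result (no akshare_info supplied; with akshare_info the extension
list also gains industry / main-business terms):
    generate_search_keywords("SZ000004", "*ST国华") ->
    {"core_keywords": ["*ST国华", "国华", "000004", "SZ000004", "*st国华"],
     "extension_keywords": []}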
Args: stock_code: 股票代码(如 SZ000004) stock_name: 股票名称(如 *ST国华) akshare_info: akshare 返回的公司信息(可选) Returns: {"core_keywords": [...], "extension_keywords": [...]} """ core_keywords = [] extension_keywords = [] # 提取纯数字代码 pure_code = stock_code if stock_code.startswith("SH") or stock_code.startswith("SZ"): pure_code = stock_code[2:] # === 1. 核心关键词(必须包含,用于确保相关性)=== # 原始名称(如 *ST国华) core_keywords.append(stock_name) # 去除 ST 标记的名称(如 国华) clean_name = stock_name for prefix in ["*ST", "ST", "S*ST", "S"]: if clean_name.startswith(prefix): clean_name = clean_name[len(prefix):] break if clean_name != stock_name and len(clean_name) >= 2: core_keywords.append(clean_name) # 股票代码 core_keywords.append(pure_code) # 000004 core_keywords.append(stock_code) # SZ000004 # 小写变体(如 st国华) core_keywords.append(stock_name.lower()) if clean_name != stock_name: core_keywords.append(clean_name.lower()) # === 2. 扩展关键词(用于组合搜索,扩大范围)=== if akshare_info: raw_data = akshare_info.get("raw_data", {}) # 公司全称(从 raw_data 中提取) company_full_name = raw_data.get("公司名称", raw_data.get("公司全称")) if company_full_name and len(company_full_name) > 4: extension_keywords.append(company_full_name) # 行业(但不单独搜索) industry = akshare_info.get("industry") if industry: extension_keywords.append(industry) # 主营业务(提取关键词) main_business = akshare_info.get("main_business", "") if main_business: import re business_parts = re.split(r'[,,、;;。\s]+', main_business) for part in business_parts[:3]: # 只取前3个 if 3 <= len(part) <= 10: # 长度适中的词 extension_keywords.append(part) # 董事长、总经理等关键人物 ceo = raw_data.get("董事长", raw_data.get("总经理")) if ceo and 2 <= len(str(ceo)) <= 4: extension_keywords.append(str(ceo)) # 去重 core_keywords = list(dict.fromkeys(core_keywords)) extension_keywords = list(dict.fromkeys(extension_keywords)) logger.info( f"📋 生成分层关键词: 核心={len(core_keywords)}个{core_keywords[:5]}, " f"扩展={len(extension_keywords)}个{extension_keywords[:5]}" ) return { "core_keywords": core_keywords, "extension_keywords": extension_keywords } @staticmethod def build_simple_graph_from_info( stock_code: str, stock_name: str, akshare_info: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """ 基于 akshare 信息构建简单的知识图谱结构 即使 akshare 调用失败,也能基于股票名称构建基础图谱 Args: stock_code: 股票代码 stock_name: 股票名称 akshare_info: akshare 返回的公司信息(可选) Returns: 简单图谱结构 """ # 提取纯数字代码 pure_code = stock_code if stock_code.startswith("SH") or stock_code.startswith("SZ"): pure_code = stock_code[2:] # 构建基础图谱 graph = { "company": { "stock_code": stock_code, "stock_name": stock_name, "pure_code": pure_code }, "name_variants": [], "industries": [], "businesses": [], "keywords": [] } # === 1. 名称变体 === graph["name_variants"].append(stock_name) # 去除 ST 标记 clean_name = stock_name for prefix in ["*ST", "ST", "S*ST", "S"]: if clean_name.startswith(prefix): clean_name = clean_name[len(prefix):] break if clean_name != stock_name: graph["name_variants"].append(clean_name) # 简称(取前两个字) if len(clean_name) >= 2: graph["name_variants"].append(clean_name[:2]) # === 2. 基于 akshare 信息填充 === if akshare_info: # 行业 industry = akshare_info.get("industry") if industry: graph["industries"].append(industry) # 板块 sector = akshare_info.get("sector") if sector: graph["industries"].append(sector) # 主营业务 main_business = akshare_info.get("main_business", "") if main_business: graph["businesses"].append(main_business[:100]) # 截取前100字 # 提取业务关键词 import re business_parts = re.split(r'[,,、;;。\s]+', main_business) for part in business_parts[:5]: if 2 <= len(part) <= 10: graph["keywords"].append(part) # === 3. 
生成搜索关键词(分层:核心 + 扩展) === keyword_groups = AkshareKnowledgeExtractor.generate_search_keywords( stock_code, stock_name, akshare_info ) graph["core_keywords"] = keyword_groups["core_keywords"] graph["extension_keywords"] = keyword_groups["extension_keywords"] logger.info(f"📊 构建简单图谱: 公司={stock_name}, 名称变体={len(graph['name_variants'])}个, " f"行业={len(graph['industries'])}个, " f"核心词={len(graph['core_keywords'])}个, 扩展词={len(graph['extension_keywords'])}个") return graph class NewsKnowledgeExtractor: """ 从新闻中提取业务变化 """ def __init__(self, extractor_agent: KnowledgeExtractorAgent): self.agent = extractor_agent async def extract_business_changes( self, stock_code: str, stock_name: str, news_list: List[Dict[str, Any]] ) -> Dict[str, Any]: """ 从新闻列表中提取业务变化 Args: stock_code: 股票代码 stock_name: 股票名称 news_list: 新闻列表 Returns: 业务变化信息 """ return await self.agent.extract_from_news(stock_code, stock_name, news_list) # 工厂函数 def create_knowledge_extractor(llm_provider=None) -> KnowledgeExtractorAgent: """创建知识提取智能体""" return KnowledgeExtractorAgent(llm_provider) ================================================ FILE: backend/app/knowledge/parallel_search.py ================================================ """ 并发多关键词检索策略 基于知识图谱的关键词,并发调用多个搜索API """ import logging import asyncio from typing import List, Dict, Any, Set from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from ..tools.bochaai_search import bochaai_search from .graph_models import SearchKeywordSet logger = logging.getLogger(__name__) class ParallelSearchStrategy: """ 并发检索策略 基于知识图谱生成的关键词,并发搜索获取更全面的新闻 """ def __init__(self, max_workers: int = 5): """ 初始化并发检索策略 Args: max_workers: 最大并发工作线程数 """ self.max_workers = max_workers def search_with_multiple_keywords( self, keyword_set: SearchKeywordSet, days: int = 30, max_results_per_query: int = 50 ) -> List[Dict[str, Any]]: """ 使用多个关键词并发搜索 Args: keyword_set: 关键词集合 days: 搜索天数 max_results_per_query: 每个查询的最大结果数 Returns: 去重后的新闻列表 """ # 生成多样化的搜索查询 queries = keyword_set.generate_search_queries(max_queries=10) logger.info(f"🔍 开始并发检索: {keyword_set.stock_name}, 查询数={len(queries)}") logger.info(f"📋 查询列表: {queries}") all_results = [] seen_urls: Set[str] = set() # 用于去重 # 并发执行搜索 with ThreadPoolExecutor(max_workers=self.max_workers) as executor: # 提交所有搜索任务 future_to_query = {} for query in queries: future = executor.submit( self._search_single_query, query, days, max_results_per_query ) future_to_query[future] = query # 收集结果 for future in as_completed(future_to_query): query = future_to_query[future] try: results = future.result() # 去重并添加 added_count = 0 for result in results: if result.url not in seen_urls: seen_urls.add(result.url) all_results.append(result) added_count += 1 logger.info(f"✅ 查询「{query}」完成: 返回{len(results)}条, 去重后新增{added_count}条") except Exception as e: logger.error(f"❌ 查询「{query}」失败: {e}") logger.info(f"🎉 并发检索完成: 共获取 {len(all_results)} 条去重后的新闻") return all_results def _search_single_query( self, query: str, days: int, count: int ) -> List[Any]: """ 执行单个查询(在线程中运行) Args: query: 搜索查询 days: 天数 count: 结果数 Returns: 搜索结果列表 """ try: if not bochaai_search.is_available(): return [] # 调用 BochaAI 搜索 results = bochaai_search.search( query=query, freshness="year", count=count, offset=0 ) return results except Exception as e: logger.error(f"搜索失败 {query}: {e}") return [] async def search_async( self, keyword_set: SearchKeywordSet, days: int = 30, max_results_per_query: int = 50 ) -> List[Dict[str, Any]]: """ 异步版本的并发搜索 Args: keyword_set: 关键词集合 days: 搜索天数 max_results_per_query: 每个查询的最大结果数 
Returns: 去重后的新闻列表 """ # 在线程池中运行同步搜索 loop = asyncio.get_event_loop() return await loop.run_in_executor( None, self.search_with_multiple_keywords, keyword_set, days, max_results_per_query ) # 便捷函数 def create_parallel_search(max_workers: int = 5) -> ParallelSearchStrategy: """创建并发检索策略""" return ParallelSearchStrategy(max_workers=max_workers) ================================================ FILE: backend/app/main.py ================================================ """ FinnewsHunter 主应用入口 """ import logging from contextlib import asynccontextmanager from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response from fastapi.exceptions import RequestValidationError from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html, get_swagger_ui_oauth2_redirect_html from starlette.middleware.base import BaseHTTPMiddleware from .core.config import settings from .core.database import init_database from .api.v1 import api_router # 配置日志 logging.basicConfig( level=getattr(logging, settings.LOG_LEVEL), format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) class DocsCSPMiddleware(BaseHTTPMiddleware): """为文档页面设置 CSP 头,允许 unsafe-eval(Swagger UI 需要)""" async def dispatch(self, request: Request, call_next): response = await call_next(request) # 只为文档页面设置 CSP if request.url.path in ["/docs", "/redoc", "/openapi.json"]: # 开发环境:完全禁用 CSP 限制(仅用于文档页面) # 生产环境应该使用更严格的策略 if settings.DEBUG: # 开发环境:允许所有内容(Swagger UI 需要) response.headers["Content-Security-Policy"] = ( "default-src * 'unsafe-inline' 'unsafe-eval' data: blob:; " "script-src * 'unsafe-inline' 'unsafe-eval'; " "style-src * 'unsafe-inline'; " "img-src * data: blob:; " "font-src * data:; " "connect-src *; " "frame-src *; " "object-src *; " "media-src *; " "worker-src * blob:; " "manifest-src *; " "form-action *; " "base-uri *; " "frame-ancestors *;" ) else: # 生产环境:使用较宽松但仍有限制的策略 response.headers["Content-Security-Policy"] = ( "default-src 'self' 'unsafe-inline' 'unsafe-eval' data: blob: https:; " "script-src 'self' 'unsafe-eval' 'unsafe-inline' https://cdn.jsdelivr.net https://unpkg.com; " "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://fonts.googleapis.com https://unpkg.com; " "font-src 'self' data: https://fonts.gstatic.com https://cdn.jsdelivr.net; " "img-src 'self' data: blob: https:; " "connect-src 'self' https:; " "frame-src 'self' https:; " "object-src 'none'; " "base-uri 'self'; " "form-action 'self'" ) return response @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理""" # 启动时执行 logger.info("=== FinnewsHunter Starting ===") logger.info(f"Environment: {'Development' if settings.DEBUG else 'Production'}") logger.info(f"LLM Provider: {settings.LLM_PROVIDER}/{settings.LLM_MODEL}") # 初始化 Neo4j 知识图谱(仅创建约束和索引,不构建具体图谱) try: from .core.neo4j_client import get_neo4j_client from .knowledge.graph_service import get_graph_service logger.info("🔍 初始化 Neo4j 知识图谱...") neo4j_client = get_neo4j_client() if neo4j_client.health_check(): logger.info("✅ Neo4j 连接正常") # 初始化约束和索引(由 graph_service 自动完成) graph_service = get_graph_service() logger.info("✅ Neo4j 约束和索引已就绪") logger.info("💡 提示: 首次定向爬取时会自动为股票构建知识图谱") else: logger.warning("⚠️ Neo4j 连接失败,知识图谱功能将不可用(不影响其他功能)") except Exception as e: logger.warning(f"⚠️ Neo4j 初始化失败: {e},知识图谱功能将不可用(不影响其他功能)") yield # 关闭时执行 logger.info("=== FinnewsHunter Shutting Down ===") # 关闭 Neo4j 连接 try: from .core.neo4j_client import close_neo4j_client close_neo4j_client() logger.info("✅ 
Neo4j 连接已关闭") except: pass # 创建 FastAPI 应用 # 禁用默认文档(我们将使用自定义 CDN) app = FastAPI( title=settings.APP_NAME, description="Financial News Analysis Platform powered by AgenticX", version=settings.APP_VERSION, debug=settings.DEBUG, lifespan=lifespan, docs_url=None, # 禁用默认文档,使用自定义路由 redoc_url=None, # 禁用默认 ReDoc,使用自定义路由 ) # 添加文档页面的 CSP 中间件(必须在 CORS 之前) app.add_middleware(DocsCSPMiddleware) # 配置 CORS # 开发环境允许所有来源(包括 file:// 协议) if settings.DEBUG: app.add_middleware( CORSMiddleware, allow_origins=["*"], # 开发环境允许所有来源 allow_credentials=False, # 允许所有来源时必须为 False allow_methods=["*"], allow_headers=["*"], ) else: # 生产环境只允许配置的来源 app.add_middleware( CORSMiddleware, allow_origins=settings.BACKEND_CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 请求验证错误处理(422错误) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): """处理请求验证错误(422)""" # 尝试读取请求体 body_str = "" try: body_bytes = await request.body() body_str = body_bytes.decode('utf-8') except Exception as e: logger.warning(f"Failed to read request body: {e}") logger.error(f"Validation error for {request.method} {request.url.path}") logger.error(f"Validation errors: {exc.errors()}") logger.error(f"Request body: {body_str}") return JSONResponse( status_code=422, content={ "detail": exc.errors(), "body": body_str if settings.DEBUG else None } ) # 全局异常处理 @app.exception_handler(Exception) async def global_exception_handler(request, exc): logger.error(f"Global exception: {exc}", exc_info=True) return JSONResponse( status_code=500, content={ "success": False, "error": "Internal server error", "detail": str(exc) if settings.DEBUG else None } ) # 根路由 @app.get("/") async def root(): """根路由 - 系统信息""" return { "name": settings.APP_NAME, "version": settings.APP_VERSION, "status": "active", "message": "Welcome to FinnewsHunter API", "docs_url": "/docs", "api_prefix": settings.API_V1_PREFIX, } # 健康检查 @app.get("/health") async def health_check(): """健康检查端点""" return { "status": "healthy", "app": settings.APP_NAME, "version": settings.APP_VERSION, } # 自定义 Swagger UI(使用 unpkg.com CDN,因为 jsdelivr.net 无法访问) @app.get("/docs", include_in_schema=False) @app.head("/docs", include_in_schema=False) async def custom_swagger_ui_html(): """自定义 Swagger UI,使用 unpkg.com CDN""" return get_swagger_ui_html( openapi_url=app.openapi_url, title=app.title + " - Swagger UI", oauth2_redirect_url="/docs/oauth2-redirect", swagger_js_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js", swagger_css_url="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css", swagger_favicon_url="https://fastapi.tiangolo.com/img/favicon.png", ) # Swagger UI OAuth2 重定向 @app.get("/docs/oauth2-redirect", include_in_schema=False) async def swagger_ui_redirect(): """Swagger UI OAuth2 重定向""" return get_swagger_ui_oauth2_redirect_html() # 自定义 ReDoc(使用 unpkg.com CDN) @app.get("/redoc", include_in_schema=False) @app.head("/redoc", include_in_schema=False) async def redoc_html(): """自定义 ReDoc,使用 unpkg.com CDN""" return get_redoc_html( openapi_url=app.openapi_url, title=app.title + " - ReDoc", redoc_js_url="https://unpkg.com/redoc@2/bundles/redoc.standalone.js", redoc_favicon_url="https://fastapi.tiangolo.com/img/favicon.png", ) # Chrome DevTools 配置文件(避免 404 日志) @app.get("/.well-known/appspecific/com.chrome.devtools.json") async def chrome_devtools_config(): """Chrome DevTools 配置文件""" return {} # 注册 API 路由 app.include_router(api_router, prefix=settings.API_V1_PREFIX) if __name__ == "__main__": import uvicorn 
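# Equivalent development invocation from the CLI: `uvicorn app.main:app --reload`
# (when run through this block, host/port/reload are read from settings instead).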
uvicorn.run( "app.main:app", host=settings.HOST, port=settings.PORT, reload=settings.DEBUG, ) ================================================ FILE: backend/app/models/__init__.py ================================================ """ 数据模型模块 """ from .database import Base, get_async_session, get_sync_session, init_db from .news import News from .stock import Stock from .analysis import Analysis from .crawl_task import CrawlTask, CrawlMode, TaskStatus from .debate_history import DebateHistory __all__ = [ "Base", "get_async_session", "get_sync_session", "init_db", "News", "Stock", "Analysis", "CrawlTask", "CrawlMode", "TaskStatus", "DebateHistory", ] ================================================ FILE: backend/app/models/analysis.py ================================================ """ 分析结果数据模型 """ from datetime import datetime from sqlalchemy import Column, Integer, String, Text, DateTime, Float, ForeignKey, JSON from sqlalchemy.orm import relationship from .database import Base class Analysis(Base): """智能体分析结果表""" __tablename__ = "analyses" # 主键 id = Column(Integer, primary_key=True, index=True, autoincrement=True) # 关联新闻 news_id = Column(Integer, ForeignKey("news.id", ondelete="CASCADE"), nullable=False, index=True) # 智能体信息 agent_name = Column(String(100), nullable=False, comment="执行分析的智能体名称") agent_role = Column(String(100), nullable=True, comment="智能体角色") # 分析结果 analysis_result = Column(Text, nullable=False, comment="分析结果(完整文本)") summary = Column(Text, nullable=True, comment="分析摘要") # 情感分析 sentiment = Column(String(20), nullable=True, comment="情感倾向(positive, negative, neutral)") sentiment_score = Column(Float, nullable=True, comment="情感评分(-1到1)") confidence = Column(Float, nullable=True, comment="置信度(0到1)") # 结构化数据 structured_data = Column(JSON, nullable=True, comment="结构化分析数据(JSON格式)") # 元数据 execution_time = Column(Float, nullable=True, comment="执行时间(秒)") llm_model = Column(String(100), nullable=True, comment="使用的LLM模型") tokens_used = Column(Integer, nullable=True, comment="消耗的Token数") # 时间戳 created_at = Column(DateTime, default=datetime.utcnow, nullable=False) # 关系 news = relationship("News", back_populates="analyses") def __repr__(self): return f"" def to_dict(self): """转换为字典""" return { "id": self.id, "news_id": self.news_id, "agent_name": self.agent_name, "agent_role": self.agent_role, "analysis_result": self.analysis_result, "summary": self.summary, "sentiment": self.sentiment, "sentiment_score": self.sentiment_score, "confidence": self.confidence, "structured_data": self.structured_data, "execution_time": self.execution_time, "llm_model": self.llm_model, "tokens_used": self.tokens_used, "created_at": self.created_at.isoformat() if self.created_at else None, } ================================================ FILE: backend/app/models/crawl_task.py ================================================ """ 爬取任务数据模型 """ from datetime import datetime from typing import Optional from sqlalchemy import Column, Integer, String, DateTime, JSON, Float from enum import Enum from .database import Base class CrawlMode(str, Enum): """爬取模式枚举""" COLD_START = "cold_start" # 冷启动(批量历史) REALTIME = "realtime" # 实时监控 TARGETED = "targeted" # 定向分析 class TaskStatus(str, Enum): """任务状态枚举""" PENDING = "pending" # 待执行 RUNNING = "running" # 执行中 COMPLETED = "completed" # 已完成 FAILED = "failed" # 失败 CANCELLED = "cancelled" # 已取消 class CrawlTask(Base): """爬取任务表""" __tablename__ = "crawl_tasks" # 主键 id = Column(Integer, primary_key=True, index=True, autoincrement=True) # 任务信息 celery_task_id = Column(String(255), 
unique=True, nullable=True, index=True, comment="Celery任务ID") mode = Column(String(20), nullable=False, index=True, comment="爬取模式") status = Column(String(20), nullable=False, default=TaskStatus.PENDING, index=True, comment="任务状态") # 任务配置 source = Column(String(100), nullable=False, comment="新闻源") config = Column(JSON, nullable=True, comment="任务配置(JSON)") # 执行进度 progress = Column(JSON, nullable=True, comment="进度信息") current_page = Column(Integer, nullable=True, comment="当前页码") total_pages = Column(Integer, nullable=True, comment="总页数") # 执行结果 result = Column(JSON, nullable=True, comment="结果统计(JSON)") crawled_count = Column(Integer, default=0, comment="爬取到的新闻数") saved_count = Column(Integer, default=0, comment="保存到数据库的新闻数") error_message = Column(String(1000), nullable=True, comment="错误信息") # 性能指标 execution_time = Column(Float, nullable=True, comment="执行时间(秒)") # 时间戳 created_at = Column(DateTime, default=datetime.utcnow, nullable=False, comment="创建时间") started_at = Column(DateTime, nullable=True, comment="开始时间") completed_at = Column(DateTime, nullable=True, comment="完成时间") def __repr__(self): return f"" def to_dict(self): """转换为字典""" return { "id": self.id, "celery_task_id": self.celery_task_id, "mode": self.mode, "status": self.status, "source": self.source, "config": self.config, "progress": self.progress, "current_page": self.current_page, "total_pages": self.total_pages, "result": self.result, "crawled_count": self.crawled_count, "saved_count": self.saved_count, "error_message": self.error_message, "execution_time": self.execution_time, "created_at": self.created_at.isoformat() if self.created_at else None, "started_at": self.started_at.isoformat() if self.started_at else None, "completed_at": self.completed_at.isoformat() if self.completed_at else None, } ================================================ FILE: backend/app/models/database.py ================================================ """ 数据库连接和会话管理 """ from typing import AsyncGenerator from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.orm import sessionmaker, declarative_base, Session from ..core.config import settings # 声明基类 Base = declarative_base() # 异步引擎(用于应用运行时) async_engine = create_async_engine( settings.DATABASE_URL, echo=settings.DEBUG, pool_pre_ping=True, pool_size=10, max_overflow=20, ) # 异步会话工厂 AsyncSessionLocal = async_sessionmaker( bind=async_engine, class_=AsyncSession, expire_on_commit=False, autocommit=False, autoflush=False, ) # 同步引擎(用于数据库初始化) sync_engine = create_engine( settings.SYNC_DATABASE_URL, echo=settings.DEBUG, pool_pre_ping=True, ) # 同步会话工厂 SyncSessionLocal = sessionmaker( bind=sync_engine, autocommit=False, autoflush=False, ) async def get_async_session() -> AsyncGenerator[AsyncSession, None]: """ 异步数据库会话依赖注入 Yields: AsyncSession: 数据库会话 """ async with AsyncSessionLocal() as session: try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close() def get_sync_session() -> Session: """ 同步数据库会话(用于初始化脚本) Returns: Session: 数据库会话 """ session = SyncSessionLocal() try: yield session session.commit() except Exception: session.rollback() raise finally: session.close() def init_db(): """ 初始化数据库表 在首次运行或重置数据库时调用 """ from .news import News from .stock import Stock from .analysis import Analysis print("Creating database tables...") Base.metadata.create_all(bind=sync_engine) print("Database tables created successfully!") if __name__ == "__main__": # 直接运行此文件以初始化数据库 
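# Assumed invocation (illustrative; only the comment above documents it): run as a module
# from the backend/ directory, e.g. `python -m app.models.database`, so that the relative
# imports (..core.config, .news, .stock, .analysis) resolve correctly.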
init_db() ================================================ FILE: backend/app/models/debate_history.py ================================================ """ 辩论历史数据模型 """ from datetime import datetime from typing import List, Optional from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Index from .database import Base class DebateHistory(Base): """辩论历史表模型""" __tablename__ = "debate_histories" # 主键 id = Column(Integer, primary_key=True, index=True, autoincrement=True) # 会话标识 session_id = Column(String(100), unique=True, nullable=False, index=True, comment="会话ID") # 股票信息 stock_code = Column(String(20), nullable=False, index=True, comment="股票代码") stock_name = Column(String(100), nullable=True, comment="股票名称") # 辩论模式 mode = Column(String(50), nullable=True, comment="辩论模式(parallel/realtime_debate/quick_analysis)") # 聊天消息(JSON数组) messages = Column(JSON, nullable=False, default=list, comment="聊天消息数组") # 时间信息 created_at = Column(DateTime, default=datetime.utcnow, nullable=False, comment="创建时间") updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间") # 索引 __table_args__ = ( # 按股票+时间查询 Index('idx_debate_stock_updated', 'stock_code', 'updated_at'), ) def __repr__(self): return f"" def to_dict(self): """转换为字典""" return { "id": self.id, "session_id": self.session_id, "stock_code": self.stock_code, "stock_name": self.stock_name, "mode": self.mode, "messages": self.messages, "created_at": self.created_at.isoformat() if self.created_at else None, "updated_at": self.updated_at.isoformat() if self.updated_at else None, } ================================================ FILE: backend/app/models/news.py ================================================ """ 新闻数据模型 - Phase 2 索引优化 """ from datetime import datetime from typing import List, Optional from sqlalchemy import Column, Integer, String, Text, DateTime, Float, ARRAY, Index from sqlalchemy.orm import relationship from .database import Base class News(Base): """新闻表模型 - Phase 2 优化版""" __tablename__ = "news" # 主键 id = Column(Integer, primary_key=True, index=True, autoincrement=True) # 基本信息 title = Column(String(500), nullable=False, index=True, comment="新闻标题") content = Column(Text, nullable=False, comment="新闻正文(解析后)") raw_html = Column(Text, nullable=True, comment="原始HTML内容") url = Column(String(1000), unique=True, nullable=False, index=True, comment="新闻URL") source = Column(String(100), nullable=False, index=True, comment="新闻来源(sina, jrj, cnstock等)") # 时间信息 publish_time = Column(DateTime, nullable=True, index=True, comment="发布时间") created_at = Column(DateTime, default=datetime.utcnow, nullable=False, comment="爬取时间") updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间") # 关联股票 stock_codes = Column(ARRAY(String), nullable=True, comment="关联的股票代码列表") # 情感分析 sentiment_score = Column(Float, nullable=True, comment="情感评分(-1到1,负面到正面)") # 其他元数据 author = Column(String(200), nullable=True, comment="作者") keywords = Column(ARRAY(String), nullable=True, comment="关键词") # 向量化标识 is_embedded = Column(Integer, default=0, comment="是否已向量化(0:否, 1:是)") # 关系 analyses = relationship("Analysis", back_populates="news", cascade="all, delete-orphan") # Phase 2: 复合索引优化(提升常见查询性能) __table_args__ = ( # 按来源+时间查询(最常用) Index('idx_source_publish_time', 'source', 'publish_time'), # 按情感+时间筛选 Index('idx_sentiment_publish_time', 'sentiment_score', 'publish_time'), ) def __repr__(self): return f"" def to_dict(self, include_html: bool = False): """转换为字典""" result = { "id": self.id, "title": self.title, 
"content": self.content, "url": self.url, "source": self.source, "publish_time": self.publish_time.isoformat() if self.publish_time else None, "created_at": self.created_at.isoformat() if self.created_at else None, "stock_codes": self.stock_codes, "sentiment_score": self.sentiment_score, "author": self.author, "keywords": self.keywords, "has_raw_html": self.raw_html is not None and len(self.raw_html or '') > 0, } if include_html and self.raw_html: result["raw_html"] = self.raw_html return result ================================================ FILE: backend/app/models/stock.py ================================================ """ 股票数据模型 """ from datetime import datetime from sqlalchemy import Column, Integer, String, DateTime, Float from .database import Base class Stock(Base): """股票基本信息表""" __tablename__ = "stocks" # 主键 id = Column(Integer, primary_key=True, index=True, autoincrement=True) # 股票基本信息 code = Column(String(20), unique=True, nullable=False, index=True, comment="股票代码(如:600519)") name = Column(String(100), nullable=False, comment="股票名称(如:贵州茅台)") full_code = Column(String(20), nullable=True, comment="完整代码(如:SH600519)") # 分类信息 industry = Column(String(100), nullable=True, comment="所属行业") market = Column(String(20), nullable=True, comment="所属市场(SH:上海, SZ:深圳)") area = Column(String(50), nullable=True, comment="所属地区") # 财务指标(可选,后续扩展) pe_ratio = Column(Float, nullable=True, comment="市盈率") market_cap = Column(Float, nullable=True, comment="总市值") # 状态 status = Column(String(20), default="active", comment="状态(active, suspended, delisted)") # 时间戳 created_at = Column(DateTime, default=datetime.utcnow, nullable=False) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) def __repr__(self): return f"" def to_dict(self): """转换为字典""" return { "id": self.id, "code": self.code, "name": self.name, "full_code": self.full_code, "industry": self.industry, "market": self.market, "area": self.area, "pe_ratio": self.pe_ratio, "market_cap": self.market_cap, "status": self.status, "created_at": self.created_at.isoformat() if self.created_at else None, "updated_at": self.updated_at.isoformat() if self.updated_at else None, } ================================================ FILE: backend/app/scripts/init_stocks.py ================================================ """ 初始化股票数据脚本 从 akshare 获取全部 A 股信息并存入 PostgreSQL 使用方法: cd backend python -m app.scripts.init_stocks """ import asyncio import logging import os from datetime import datetime from pathlib import Path # ⚠️ 禁用代理(akshare 需要直连国内网站) for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'all_proxy', 'ALL_PROXY']: os.environ.pop(proxy_var, None) # 设置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # 加载 .env from dotenv import load_dotenv env_path = Path(__file__).parent.parent.parent / ".env" load_dotenv(env_path) logger.info(f"Loaded .env from: {env_path}") # 构建数据库 URL DATABASE_URL = os.getenv("DATABASE_URL", "") if not DATABASE_URL: # 从分开的变量构建 DATABASE_URL pg_user = os.getenv("POSTGRES_USER", "finnews") pg_password = os.getenv("POSTGRES_PASSWORD", "finnews_dev_password") pg_host = os.getenv("POSTGRES_HOST", "localhost") pg_port = os.getenv("POSTGRES_PORT", "5432") pg_db = os.getenv("POSTGRES_DB", "finnews_db") DATABASE_URL = f"postgresql+asyncpg://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}" logger.info(f"Built DATABASE_URL from individual variables") elif DATABASE_URL.startswith("postgresql://"): DATABASE_URL = 
DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1) logger.info(f"Database: {DATABASE_URL.split('@')[-1] if '@' in DATABASE_URL else DATABASE_URL[:30]}...") # 导入依赖 try: import akshare as ak import pandas as pd AKSHARE_AVAILABLE = True logger.info("akshare loaded successfully") except ImportError: AKSHARE_AVAILABLE = False logger.error("akshare not installed! Run: pip install akshare") exit(1) from sqlalchemy import Column, Integer, String, DateTime, Float, text from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base Base = declarative_base() class Stock(Base): """股票基本信息表""" __tablename__ = "stocks" id = Column(Integer, primary_key=True, index=True, autoincrement=True) code = Column(String(20), unique=True, nullable=False, index=True) name = Column(String(100), nullable=False) full_code = Column(String(20), nullable=True) industry = Column(String(100), nullable=True) market = Column(String(20), nullable=True) area = Column(String(50), nullable=True) pe_ratio = Column(Float, nullable=True) market_cap = Column(Float, nullable=True) status = Column(String(20), default="active") created_at = Column(DateTime, default=datetime.utcnow, nullable=False) updated_at = Column(DateTime, default=datetime.utcnow) def get_fallback_stocks() -> list: """备用股票列表(如果 akshare 失败时使用)""" return [ {"code": "600519", "name": "贵州茅台", "full_code": "SH600519", "market": "SH", "status": "active"}, {"code": "000001", "name": "平安银行", "full_code": "SZ000001", "market": "SZ", "status": "active"}, {"code": "601318", "name": "中国平安", "full_code": "SH601318", "market": "SH", "status": "active"}, {"code": "000858", "name": "五粮液", "full_code": "SZ000858", "market": "SZ", "status": "active"}, {"code": "002594", "name": "比亚迪", "full_code": "SZ002594", "market": "SZ", "status": "active"}, {"code": "600036", "name": "招商银行", "full_code": "SH600036", "market": "SH", "status": "active"}, {"code": "601166", "name": "兴业银行", "full_code": "SH601166", "market": "SH", "status": "active"}, {"code": "000333", "name": "美的集团", "full_code": "SZ000333", "market": "SZ", "status": "active"}, {"code": "002415", "name": "海康威视", "full_code": "SZ002415", "market": "SZ", "status": "active"}, {"code": "600276", "name": "恒瑞医药", "full_code": "SH600276", "market": "SH", "status": "active"}, {"code": "000002", "name": "万科A", "full_code": "SZ000002", "market": "SZ", "status": "active"}, {"code": "600887", "name": "伊利股份", "full_code": "SH600887", "market": "SH", "status": "active"}, {"code": "000725", "name": "京东方A", "full_code": "SZ000725", "market": "SZ", "status": "active"}, {"code": "600000", "name": "浦发银行", "full_code": "SH600000", "market": "SH", "status": "active"}, {"code": "000063", "name": "中兴通讯", "full_code": "SZ000063", "market": "SZ", "status": "active"}, {"code": "600104", "name": "上汽集团", "full_code": "SH600104", "market": "SH", "status": "active"}, {"code": "002304", "name": "洋河股份", "full_code": "SZ002304", "market": "SZ", "status": "active"}, {"code": "600585", "name": "海螺水泥", "full_code": "SH600585", "market": "SH", "status": "active"}, {"code": "000876", "name": "新希望", "full_code": "SZ000876", "market": "SZ", "status": "active"}, {"code": "600309", "name": "万华化学", "full_code": "SH600309", "market": "SH", "status": "active"}, ] async def fetch_all_stocks() -> list: """从 akshare 获取全部 A 股信息""" logger.info("Fetching all A-share stocks from akshare...") # 设置 requests 不使用代理 import requests session = requests.Session() session.proxies = { 'http': None, 'https': None, } 
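# Note: the `session` created above (proxies disabled, custom User-Agent) is not actually
# passed into the akshare calls below -- ak.stock_zh_a_spot_em() and ak.stock_info_a_code_name()
# manage their own HTTP requests. The effective proxy bypass therefore relies on the proxy
# environment variables popped at module import time (see the top of this script).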
# 设置 User-Agent session.headers.update({ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' }) max_retries = 3 for attempt in range(max_retries): try: logger.info(f"Attempt {attempt + 1}/{max_retries}...") # 方法1: 尝试使用 stock_zh_a_spot_em try: df = ak.stock_zh_a_spot_em() except Exception as e1: logger.warning(f"Method 1 failed: {e1}") # 方法2: 尝试使用 stock_info_a_code_name try: logger.info("Trying alternative method: stock_info_a_code_name...") df = ak.stock_info_a_code_name() if df is not None and not df.empty: # 重命名列 df.columns = ['代码', '名称'] except Exception as e2: logger.warning(f"Method 2 failed: {e2}") raise e1 # 抛出第一个错误 if df is None or df.empty: logger.error("No data returned from akshare") if attempt < max_retries - 1: await asyncio.sleep(2) # 等待2秒后重试 continue return [] logger.info(f"✅ Fetched {len(df)} stocks from akshare") stocks = [] for _, row in df.iterrows(): code = str(row['代码']) name = str(row['名称']) # 跳过异常数据 if not code or not name or name in ['N/A', 'nan', '']: continue # 确定市场前缀 if code.startswith('6'): market = "SH" full_code = f"SH{code}" elif code.startswith('0') or code.startswith('3'): market = "SZ" full_code = f"SZ{code}" else: market = "OTHER" full_code = code stocks.append({ "code": code, "name": name, "full_code": full_code, "market": market, "status": "active", }) return stocks except Exception as e: logger.error(f"Attempt {attempt + 1} failed: {e}") if attempt < max_retries - 1: wait_time = (attempt + 1) * 2 logger.info(f"Waiting {wait_time} seconds before retry...") await asyncio.sleep(wait_time) else: logger.error("All attempts failed!") import traceback traceback.print_exc() return [] return [] async def init_stocks_to_db(): """初始化股票数据到数据库""" # 创建数据库引擎 engine = create_async_engine(DATABASE_URL, echo=False) async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) # 确保表存在 async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) # 获取股票数据 stocks_data = await fetch_all_stocks() if not stocks_data: logger.warning("⚠️ Failed to fetch from akshare, using fallback stock list...") # 备用方案:导入常用股票 stocks_data = get_fallback_stocks() if not stocks_data: logger.error("No stocks to insert") await engine.dispose() return logger.info(f"Using {len(stocks_data)} fallback stocks") async with async_session() as session: try: # 清空现有数据 logger.info("Clearing existing stock data...") await session.execute(text("DELETE FROM stocks")) await session.commit() # 批量插入 logger.info(f"Inserting {len(stocks_data)} stocks...") batch_size = 500 for i in range(0, len(stocks_data), batch_size): batch = stocks_data[i:i + batch_size] for stock_data in batch: stock = Stock( code=stock_data["code"], name=stock_data["name"], full_code=stock_data["full_code"], market=stock_data["market"], status=stock_data["status"], created_at=datetime.utcnow(), updated_at=datetime.utcnow(), ) session.add(stock) await session.commit() logger.info(f"Inserted batch {i // batch_size + 1}, total: {min(i + batch_size, len(stocks_data))}/{len(stocks_data)}") logger.info(f"✅ Successfully initialized {len(stocks_data)} stocks!") except Exception as e: logger.error(f"Failed to insert stocks: {e}") import traceback traceback.print_exc() await session.rollback() finally: await engine.dispose() async def get_stock_count(): """获取数据库中股票数量""" engine = create_async_engine(DATABASE_URL, echo=False) async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) async with async_session() as session: result = await session.execute(text("SELECT 
COUNT(*) FROM stocks")) count = result.scalar() or 0 logger.info(f"Current stock count in database: {count}") await engine.dispose() return count async def main(): print("=" * 60) print("🚀 Stock Data Initialization Script") print("=" * 60) # 检查当前数量 try: await get_stock_count() except Exception as e: logger.warning(f"Could not get current count (table may not exist): {e}") # 执行初始化 print("\n📥 Starting initialization...") await init_stocks_to_db() # 再次检查 print("\n📊 After initialization:") await get_stock_count() print("\n✅ Done!") if __name__ == "__main__": asyncio.run(main()) ================================================ FILE: backend/app/services/__init__.py ================================================ """ 服务模块 """ from .llm_service import get_llm_provider, get_llm_service, LLMService from .embedding_service import get_embedding_service, EmbeddingService from .analysis_service import get_analysis_service, AnalysisService __all__ = [ "get_llm_provider", "get_llm_service", "LLMService", "get_embedding_service", "EmbeddingService", "get_analysis_service", "AnalysisService", ] ================================================ FILE: backend/app/services/analysis_service.py ================================================ """ 新闻分析服务 协调智能体执行分析任务 """ import logging import time from typing import Dict, Any, Optional from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from starlette.concurrency import run_in_threadpool from ..models.database import AsyncSessionLocal from ..agents import create_news_analyst from ..models.news import News from ..models.analysis import Analysis from ..services.embedding_service import get_embedding_service from ..storage.vector_storage import get_vector_storage logger = logging.getLogger(__name__) class AnalysisService: """ 新闻分析服务 负责协调智能体执行新闻分析任务 """ def __init__(self): """初始化分析服务""" self.news_analyst = create_news_analyst() self.embedding_service = get_embedding_service() self.vector_storage = get_vector_storage() logger.info("Initialized AnalysisService") async def analyze_news( self, news_id: int, db: AsyncSession, llm_provider: Optional[str] = None, llm_model: Optional[str] = None ) -> Dict[str, Any]: """ 分析指定新闻 Args: news_id: 新闻ID db: 数据库会话 llm_provider: 模型厂商(可选:bailian, openai, deepseek, kimi) llm_model: 模型名称(可选) Returns: 分析结果 """ start_time = time.time() # 如果指定了自定义模型,创建临时的智能体 if llm_provider and llm_model: from ..services.llm_service import create_custom_llm_provider from ..agents.news_analyst import NewsAnalystAgent logger.info(f"Using custom model: {llm_provider}/{llm_model}") custom_llm = create_custom_llm_provider(llm_provider, llm_model) analyst = NewsAnalystAgent(llm_provider=custom_llm) else: analyst = self.news_analyst try: # 1. 查询新闻 result = await db.execute( select(News).where(News.id == news_id) ) news = result.scalar_one_or_none() if not news: return { "success": False, "error": f"News not found: {news_id}" } logger.info(f"Analyzing news: {news_id} - {news.title}") # 2. 执行智能体分析 # 注意:由于 agent.analyze_news 是同步方法,需要在线程池中运行以避免阻塞异步事件循环 analysis_result = await run_in_threadpool( analyst.analyze_news, # 使用 analyst(可能是自定义的或默认的) news_title=news.title, news_content=news.content, news_url=news.url, stock_codes=news.stock_codes or [] ) if not analysis_result.get("success"): return analysis_result # 3. 
保存分析结果到数据库 structured_data = analysis_result.get("structured_data", {}) analysis = Analysis( news_id=news_id, agent_name=analysis_result.get("agent_name"), agent_role=analysis_result.get("agent_role"), analysis_result=analysis_result.get("analysis_result", ""), summary=structured_data.get("market_impact", "")[:500], sentiment=structured_data.get("sentiment"), sentiment_score=structured_data.get("sentiment_score"), confidence=structured_data.get("confidence"), structured_data=structured_data, execution_time=time.time() - start_time, llm_model=f"{llm_provider}/{llm_model}" if llm_provider and llm_model else (analyst._llm_provider.model if hasattr(analyst, '_llm_provider') and hasattr(analyst._llm_provider, 'model') else None), ) db.add(analysis) # 4. 更新新闻的情感评分 news.sentiment_score = structured_data.get("sentiment_score") # 5. 向量化新闻内容(如果尚未向量化) # 注意:embedding是可选功能,失败不应影响分析结果 # 在后台异步执行,不阻塞分析流程 if not news.is_embedded: # 使用 asyncio.create_task 在后台执行,不等待结果 # 这样即使embedding超时或失败,也不会影响分析结果的返回 import asyncio async def vectorize_in_background(): try: # 组合标题和内容进行向量化 text_to_embed = f"{news.title}\n{news.content[:1000]}" # 使用异步方法,避免事件循环问题 embedding = await asyncio.wait_for( self.embedding_service.aembed_text(text_to_embed), timeout=20.0 # 20秒超时,避免等待太久 ) # 存储到 Milvus(也在线程池中执行) await run_in_threadpool( self.vector_storage.store_embedding, news_id=news_id, embedding=embedding, text=text_to_embed ) # 更新数据库中的is_embedded标志(需要新的数据库会话) async with AsyncSessionLocal() as update_db: try: result = await update_db.execute( select(News).where(News.id == news_id) ) update_news = result.scalar_one_or_none() if update_news: update_news.is_embedded = 1 await update_db.commit() logger.info(f"Vectorized news: {news_id}") except Exception as e: logger.warning(f"Failed to update is_embedded flag for news {news_id}: {e}") await update_db.rollback() except asyncio.TimeoutError: logger.warning(f"Embedding timeout for news {news_id} (20s), skipping vectorization") except Exception as e: logger.warning(f"Failed to vectorize news {news_id}: {e}") # 在后台执行,不等待完成 asyncio.create_task(vectorize_in_background()) await db.commit() await db.refresh(analysis) logger.info(f"Analysis completed for news {news_id}, execution time: {analysis.execution_time:.2f}s") return { "success": True, "analysis_id": analysis.id, "news_id": news_id, "sentiment": analysis.sentiment, "sentiment_score": analysis.sentiment_score, "confidence": analysis.confidence, "summary": analysis.summary, "execution_time": analysis.execution_time, } except Exception as e: logger.error(f"Analysis failed for news {news_id}: {e}") await db.rollback() return { "success": False, "error": str(e) } async def get_analysis_by_id( self, analysis_id: int, db: AsyncSession ) -> Optional[Dict[str, Any]]: """ 获取分析结果 Args: analysis_id: 分析ID db: 数据库会话 Returns: 分析结果或None """ try: result = await db.execute( select(Analysis).where(Analysis.id == analysis_id) ) analysis = result.scalar_one_or_none() if analysis: return analysis.to_dict() return None except Exception as e: logger.error(f"Failed to get analysis {analysis_id}: {e}") return None async def get_analyses_by_news_id( self, news_id: int, db: AsyncSession ) -> list: """ 获取指定新闻的所有分析结果(按时间倒序,最新的在前) Args: news_id: 新闻ID db: 数据库会话 Returns: 分析结果列表(最新的在前) """ try: from sqlalchemy import desc result = await db.execute( select(Analysis) .where(Analysis.news_id == news_id) .order_by(desc(Analysis.created_at)) # 按创建时间倒序,最新的在前 ) analyses = result.scalars().all() return [analysis.to_dict() for analysis in analyses] except Exception as e: 
logger.error(f"Failed to get analyses for news {news_id}: {e}") return [] # 全局实例 _analysis_service: Optional[AnalysisService] = None def get_analysis_service() -> AnalysisService: """ 获取分析服务实例(单例模式) Returns: AnalysisService 实例 """ global _analysis_service if _analysis_service is None: _analysis_service = AnalysisService() return _analysis_service ================================================ FILE: backend/app/services/embedding_service.py ================================================ """ Embedding 服务封装 使用 agenticx.embeddings.BailianEmbeddingProvider """ import logging import asyncio from typing import List, Optional import redis import hashlib import json from ..core.config import settings from agenticx.embeddings import BailianEmbeddingProvider logger = logging.getLogger(__name__) class EmbeddingService: """ Embedding 服务封装类 基于 agenticx.embeddings.BailianEmbeddingProvider 提供文本向量化功能,支持缓存 """ def __init__( self, provider: str = None, model: str = None, batch_size: int = None, enable_cache: bool = True, base_url: str = None, ): """ 初始化 Embedding 服务 Args: provider: 提供商(保留参数以兼容,实际使用 bailian) model: 模型名称 batch_size: 批处理大小 enable_cache: 是否启用Redis缓存 base_url: 自定义 API 端点(用于百炼等第三方服务) """ self.provider = provider or settings.EMBEDDING_PROVIDER self.model = model or settings.EMBEDDING_MODEL self.batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE self.enable_cache = enable_cache self.base_url = base_url or settings.EMBEDDING_BASE_URL # 获取 API Key api_key = settings.DASHSCOPE_API_KEY if not api_key: # 如果没有 DASHSCOPE_API_KEY,尝试使用 OPENAI_API_KEY(向后兼容) api_key = settings.OPENAI_API_KEY if not api_key: raise ValueError("DASHSCOPE_API_KEY or OPENAI_API_KEY is required for embedding") # 设置 API URL api_url = self.base_url or settings.DASHSCOPE_BASE_URL or "https://dashscope.aliyuncs.com/compatible-mode/v1" # 初始化 agenticx BailianEmbeddingProvider self.provider_instance = BailianEmbeddingProvider( api_key=api_key, model=self.model, api_url=api_url, batch_size=self.batch_size, timeout=settings.EMBEDDING_TIMEOUT, retry_count=settings.EMBEDDING_MAX_RETRIES, dimensions=settings.MILVUS_DIM, # 确保维度匹配 use_dashscope_sdk=False # 使用 HTTP API,避免 SDK 依赖问题 ) logger.info(f"Initialized BailianEmbeddingProvider: {self.model}, dimension={self.provider_instance.get_embedding_dim()}") # 初始化Redis缓存 if self.enable_cache: try: self.redis_client = redis.from_url(settings.REDIS_URL) self.cache_ttl = 86400 * 7 # 7天 logger.info("Redis cache enabled for embeddings") except Exception as e: logger.warning(f"Failed to connect to Redis, cache disabled: {e}") self.enable_cache = False def _get_cache_key(self, text: str) -> str: """生成缓存键""" # 使用文本的MD5哈希和模型名称作为键 text_hash = hashlib.md5(text.encode()).hexdigest() return f"embedding:{self.model}:{text_hash}" def _get_from_cache(self, text: str) -> Optional[List[float]]: """从缓存获取向量""" if not self.enable_cache: return None try: cache_key = self._get_cache_key(text) cached = self.redis_client.get(cache_key) if cached: return json.loads(cached) except Exception as e: logger.warning(f"Failed to get from cache: {e}") return None def _save_to_cache(self, text: str, embedding: List[float]): """保存向量到缓存""" if not self.enable_cache: return try: cache_key = self._get_cache_key(text) self.redis_client.setex( cache_key, self.cache_ttl, json.dumps(embedding) ) except Exception as e: logger.warning(f"Failed to save to cache: {e}") def embed_text(self, text: str) -> List[float]: """ 将文本转换为向量 Args: text: 文本 Returns: 向量(List[float]) """ # 检查缓存 cached = self._get_from_cache(text) if cached is not None: 
return cached # 限制文本长度(避免超过模型限制) max_length = 6000 if len(text) > max_length: logger.warning(f"Text too long ({len(text)} chars), truncating to {max_length} chars") text = text[:max_length] # 生成向量(使用 agenticx provider) # 注意:embed() 方法内部使用 asyncio.run(),在同步上下文中可以直接调用 # 如果在异步上下文中调用此同步方法,应该在 ThreadPoolExecutor 中运行 try: # 直接调用 embed(),它内部会使用 asyncio.run() 创建新的事件循环 # 这在同步上下文中可以正常工作 # 如果在异步上下文中,调用者应该在 ThreadPoolExecutor 中运行此方法 embeddings = self.provider_instance.embed([text]) embedding = embeddings[0] if embeddings else [] # 保存到缓存 self._save_to_cache(text, embedding) return embedding except Exception as e: logger.error(f"Embedding failed for text: {text[:100]}..., error: {e}") raise def embed_batch(self, texts: List[str]) -> List[List[float]]: """ 批量将文本转换为向量 Args: texts: 文本列表 Returns: 向量列表 """ if not texts: return [] # 检查缓存并分离需要处理的文本 embeddings_map = {} # {index: embedding} texts_to_embed = [] # [(index, text), ...] max_length = 6000 for idx, text in enumerate(texts): # 检查缓存 cached = self._get_from_cache(text) if cached is not None: embeddings_map[idx] = cached else: # 限制文本长度 if len(text) > max_length: logger.warning(f"Text too long ({len(text)} chars), truncating to {max_length} chars") text = text[:max_length] texts_to_embed.append((idx, text)) # 对未缓存的文本批量生成向量 # 注意:BailianEmbeddingProvider.embed() 内部已经会分批处理,不需要我们再次分批 if texts_to_embed: try: texts_list = [t[1] for t in texts_to_embed] # 直接调用 embed(),它内部会使用 asyncio.run() 创建新的事件循环 # BailianEmbeddingProvider 内部会根据 batch_size 自动分批处理 new_embeddings = self.provider_instance.embed(texts_list) # 保存到缓存并添加到结果 for (idx, text), embedding in zip(texts_to_embed, new_embeddings): self._save_to_cache(text, embedding) embeddings_map[idx] = embedding except Exception as e: logger.error(f"Batch embedding failed: {e}") raise # 按原始顺序返回结果 return [embeddings_map.get(i, []) for i in range(len(texts))] async def aembed_text(self, text: str) -> List[float]: """ 异步将文本转换为向量(推荐在异步上下文中使用) Args: text: 文本 Returns: 向量(List[float]) """ # 检查缓存 cached = self._get_from_cache(text) if cached is not None: return cached # 限制文本长度(避免超过模型限制) max_length = 6000 if len(text) > max_length: logger.warning(f"Text too long ({len(text)} chars), truncating to {max_length} chars") text = text[:max_length] # 使用异步接口,避免 asyncio.run() 的问题 try: embeddings = await self.provider_instance.aembed([text]) embedding = embeddings[0] if embeddings else [] # 保存到缓存 self._save_to_cache(text, embedding) return embedding except Exception as e: logger.error(f"Embedding failed for text: {text[:100]}..., error: {e}") raise async def aembed_batch(self, texts: List[str]) -> List[List[float]]: """ 异步批量将文本转换为向量(推荐在异步上下文中使用) Args: texts: 文本列表 Returns: 向量列表 """ if not texts: return [] # 检查缓存并分离需要处理的文本 embeddings_map = {} # {index: embedding} texts_to_embed = [] # [(index, text), ...] 
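# Example of the bookkeeping below: for texts = [t0, t1, t2] where only t1 is already cached,
# the loop yields embeddings_map = {1: <cached vector>} and texts_to_embed = [(0, t0), (2, t2)];
# after aembed() returns, the new vectors are written back into embeddings_map and the result
# list is rebuilt in the original order 0, 1, 2 (missing entries fall back to []).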
max_length = 6000 for idx, text in enumerate(texts): # 检查缓存 cached = self._get_from_cache(text) if cached is not None: embeddings_map[idx] = cached else: # 限制文本长度 if len(text) > max_length: logger.warning(f"Text too long ({len(text)} chars), truncating to {max_length} chars") text = text[:max_length] texts_to_embed.append((idx, text)) # 对未缓存的文本批量生成向量 # BailianEmbeddingProvider.aembed() 内部已经会分批处理 if texts_to_embed: try: texts_list = [t[1] for t in texts_to_embed] # 使用异步接口,避免 asyncio.run() 的问题 new_embeddings = await self.provider_instance.aembed(texts_list) # 保存到缓存并添加到结果 for (idx, text), embedding in zip(texts_to_embed, new_embeddings): self._save_to_cache(text, embedding) embeddings_map[idx] = embedding except Exception as e: logger.error(f"Batch embedding failed: {e}") raise # 按原始顺序返回结果 return [embeddings_map.get(i, []) for i in range(len(texts))] # 全局实例 _embedding_service: Optional[EmbeddingService] = None def get_embedding_service() -> EmbeddingService: """ 获取 Embedding 服务实例(单例模式) Returns: EmbeddingService 实例 """ global _embedding_service if _embedding_service is None: _embedding_service = EmbeddingService() return _embedding_service ================================================ FILE: backend/app/services/llm_service.py ================================================ """ LLM 服务封装 """ import logging from typing import Optional, Dict, Any, Union from agenticx import LiteLLMProvider, LLMResponse from agenticx.llms.bailian_provider import BailianProvider from ..core.config import settings logger = logging.getLogger(__name__) class LLMService: """ LLM 服务封装类 提供统一的 LLM 调用接口 """ def __init__( self, provider: str = None, model: str = None, temperature: float = None, max_tokens: int = None, api_key: str = None, base_url: str = None, ): """ 初始化 LLM 服务 Args: provider: 提供商(openai, anthropic, ollama) model: 模型名称 temperature: 温度参数 max_tokens: 最大token数 api_key: API密钥 base_url: 自定义 API 端点(用于第三方转发) """ self.provider_name = provider or settings.LLM_PROVIDER self.model = model or settings.LLM_MODEL self.temperature = temperature or settings.LLM_TEMPERATURE self.max_tokens = max_tokens or settings.LLM_MAX_TOKENS # 设置API密钥 if api_key: self.api_key = api_key elif self.provider_name == "bailian": self.api_key = settings.DASHSCOPE_API_KEY or settings.BAILIAN_API_KEY elif self.provider_name == "openai": self.api_key = settings.OPENAI_API_KEY elif self.provider_name == "deepseek": self.api_key = settings.DEEPSEEK_API_KEY elif self.provider_name == "kimi": self.api_key = settings.MOONSHOT_API_KEY elif self.provider_name == "zhipu": self.api_key = settings.ZHIPU_API_KEY elif self.provider_name == "anthropic": self.api_key = settings.ANTHROPIC_API_KEY else: self.api_key = None # 设置 Base URL(用于第三方 API 转发) if base_url: self.base_url = base_url elif self.provider_name == "bailian": self.base_url = settings.DASHSCOPE_BASE_URL elif self.provider_name == "openai": self.base_url = settings.OPENAI_BASE_URL elif self.provider_name == "deepseek": self.base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com/v1" elif self.provider_name == "kimi": self.base_url = settings.MOONSHOT_BASE_URL or "https://api.moonshot.cn/v1" elif self.provider_name == "zhipu": self.base_url = settings.ZHIPU_BASE_URL or "https://open.bigmodel.cn/api/paas/v4" elif self.provider_name == "anthropic": self.base_url = settings.ANTHROPIC_BASE_URL else: self.base_url = None # 创建 LLM 提供者 self.llm_provider = self._create_provider() def _create_provider(self) -> Union[LiteLLMProvider, BailianProvider]: """创建 LLM 提供者""" try: # 检测是否使用 
Dashscope/Bailian API is_dashscope = ( self.base_url and "dashscope" in self.base_url.lower() ) or ( self.model and self.model.startswith("qwen") and self.base_url ) if is_dashscope: # 使用 BailianProvider(专门为百炼 API 设计) if not self.api_key: raise ValueError("API key is required for Bailian provider") provider = BailianProvider( model=self.model, api_key=self.api_key, base_url=self.base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1", temperature=self.temperature, timeout=float(settings.LLM_TIMEOUT), # 从配置读取超时时间 max_retries=2 # 减少重试次数,避免总耗时过长 ) logger.info(f"Initialized BailianProvider: {self.model}") return provider else: # 使用 LiteLLMProvider(通用 provider) provider_kwargs = { "model": self.model, "temperature": self.temperature, "max_tokens": self.max_tokens, "api_key": self.api_key, } # 如果设置了自定义 base_url,添加到配置中 if self.base_url: provider_kwargs["base_url"] = self.base_url logger.info(f"Using custom base URL: {self.base_url}") provider = LiteLLMProvider(**provider_kwargs) logger.info(f"Initialized LiteLLMProvider: {self.provider_name}/{self.model}") return provider except Exception as e: logger.error(f"Failed to initialize LLM provider: {e}") raise def generate( self, prompt: str, system_message: Optional[str] = None, **kwargs ) -> str: """ 生成文本 Args: prompt: 用户提示 system_message: 系统消息 **kwargs: 额外参数 Returns: 生成的文本 """ try: messages = [] if system_message: messages.append({"role": "system", "content": system_message}) messages.append({"role": "user", "content": prompt}) # 确保传递 max_tokens(如果 kwargs 中没有) if "max_tokens" not in kwargs: kwargs["max_tokens"] = self.max_tokens response: LLMResponse = self.llm_provider.generate( messages=messages, **kwargs ) return response.content except Exception as e: logger.error(f"LLM generation failed: {e}") raise def analyze_sentiment(self, text: str) -> Dict[str, Any]: """ 分析文本情感 Args: text: 待分析文本 Returns: 情感分析结果 """ system_message = """你是一个专业的金融新闻情感分析专家。 请分析给定新闻的情感倾向,判断其对相关股票的影响是利好、利空还是中性。 输出格式(JSON): { "sentiment": "positive/negative/neutral", "score": 0.0-1.0(情感强度), "confidence": 0.0-1.0(置信度), "reasoning": "分析理由" } """ prompt = f"""请分析以下新闻的情感倾向: {text[:1000]} 请严格按照JSON格式输出结果。""" try: response_text = self.generate(prompt, system_message) # 尝试解析JSON import json import re # 提取JSON部分 json_match = re.search(r'\{.*\}', response_text, re.DOTALL) if json_match: result = json.loads(json_match.group()) return result else: # 如果无法解析,返回默认值 return { "sentiment": "neutral", "score": 0.5, "confidence": 0.5, "reasoning": response_text } except Exception as e: logger.error(f"Sentiment analysis failed: {e}") return { "sentiment": "neutral", "score": 0.5, "confidence": 0.0, "reasoning": f"分析失败: {str(e)}" } def summarize(self, text: str, max_length: int = 200) -> str: """ 文本摘要 Args: text: 原始文本 max_length: 摘要最大长度 Returns: 摘要文本 """ system_message = f"""你是一个专业的金融新闻摘要专家。 请将给定的新闻内容总结为不超过{max_length}字的简洁摘要,保留关键信息。""" prompt = f"""请总结以下新闻: {text} 摘要:""" try: summary = self.generate(prompt, system_message, max_tokens=max_length) return summary.strip() except Exception as e: logger.error(f"Summarization failed: {e}") return text[:max_length] + "..." 
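# Usage sketch (illustrative only; assumes the relevant API key, e.g. DASHSCOPE_API_KEY or
# OPENAI_API_KEY, is configured in settings; `news_text` is a placeholder variable):
#
#     service = get_llm_service()                       # module-level singleton defined below
#     summary = service.summarize(news_text, max_length=120)
#     sentiment = service.analyze_sentiment(news_text)  # {"sentiment": ..., "score": ..., "confidence": ..., "reasoning": ...}
#
# Both helpers go through LLMService.generate(), which assembles the system/user messages and
# forwards extra keyword arguments to the underlying provider.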
# 全局实例 _llm_service: Optional[LLMService] = None def get_llm_provider( provider: Optional[str] = None, model: Optional[str] = None ) -> Union[LiteLLMProvider, BailianProvider]: """ 获取 LLM 提供者实例(用于 AgenticX Agent) Args: provider: 可选的提供商名称(如 openai, bailian, ollama) model: 可选的模型名称 Returns: LiteLLMProvider 或 BailianProvider 实例 """ global _llm_service # 如果指定了 provider 或 model,创建新的实例 if provider or model: custom_service = LLMService(provider=provider, model=model) return custom_service.llm_provider # 否则使用全局实例 if _llm_service is None: _llm_service = LLMService() return _llm_service.llm_provider def get_llm_service() -> LLMService: """ 获取 LLM 服务实例 Returns: LLMService 实例 """ global _llm_service if _llm_service is None: _llm_service = LLMService() return _llm_service def create_custom_llm_provider( provider: Optional[str] = None, model: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, ) -> Union[LiteLLMProvider, BailianProvider]: """ 动态创建自定义 LLM provider(用于模型切换) Args: provider: 厂商名称(bailian, openai, deepseek, kimi, zhipu) model: 模型名称 temperature: 温度参数 max_tokens: 最大token数 api_key: API Key(可选,优先从settings读取) base_url: Base URL(可选,优先从settings读取) Returns: LLM provider 实例 Examples: >>> llm = create_custom_llm_provider('bailian', 'qwen-max') >>> llm = create_custom_llm_provider('openai', 'gpt-4') >>> llm = create_custom_llm_provider('zhipu', 'glm-4') """ _provider = provider or settings.LLM_PROVIDER _model = model or settings.LLM_MODEL _temperature = temperature if temperature is not None else settings.LLM_TEMPERATURE _max_tokens = max_tokens if max_tokens is not None else settings.LLM_MAX_TOKENS logger.info(f"Creating custom LLM provider: {_provider}/{_model}") try: if _provider == 'bailian': # 使用阿里云百炼(通过 OpenAI 兼容接口) _api_key = api_key or settings.DASHSCOPE_API_KEY or settings.BAILIAN_API_KEY if not _api_key: raise ValueError("DASHSCOPE_API_KEY or BAILIAN_API_KEY is required for bailian provider") _base_url = base_url or settings.DASHSCOPE_BASE_URL return BailianProvider( model=_model, api_key=_api_key, base_url=_base_url, access_key_id=settings.BAILIAN_ACCESS_KEY_ID, access_key_secret=settings.BAILIAN_ACCESS_KEY_SECRET, agent_code=settings.BAILIAN_AGENT_CODE, region_id=settings.BAILIAN_REGION_ID, temperature=_temperature, max_tokens=_max_tokens, timeout=float(settings.LLM_TIMEOUT), # 从配置读取超时时间 max_retries=2 # 减少重试次数,避免总耗时过长 ) elif _provider == 'openai': # 使用 OpenAI _api_key = api_key or settings.OPENAI_API_KEY if not _api_key: raise ValueError("OPENAI_API_KEY is required for openai provider") _base_url = base_url or settings.OPENAI_BASE_URL return LiteLLMProvider( provider="openai", model=_model, api_key=_api_key, base_url=_base_url, temperature=_temperature, max_tokens=_max_tokens ) elif _provider == 'deepseek': # 使用 DeepSeek(通过 OpenAI 兼容接口) _api_key = api_key or settings.DEEPSEEK_API_KEY if not _api_key: raise ValueError("DEEPSEEK_API_KEY is required for deepseek provider") _base_url = base_url or settings.DEEPSEEK_BASE_URL or 'https://api.deepseek.com/v1' return LiteLLMProvider( provider="openai", model=_model, api_key=_api_key, base_url=_base_url, temperature=_temperature, max_tokens=_max_tokens ) elif _provider == 'kimi': # 使用 Kimi (Moonshot) _api_key = api_key or settings.MOONSHOT_API_KEY if not _api_key: raise ValueError("MOONSHOT_API_KEY is required for kimi provider") _base_url = base_url or settings.MOONSHOT_BASE_URL or 'https://api.moonshot.cn/v1' return LiteLLMProvider( 
provider="openai", model=_model, api_key=_api_key, base_url=_base_url, temperature=_temperature, max_tokens=_max_tokens ) elif _provider == 'zhipu': # 使用智谱 AI _api_key = api_key or settings.ZHIPU_API_KEY if not _api_key: raise ValueError("ZHIPU_API_KEY is required for zhipu provider") _base_url = base_url or settings.ZHIPU_BASE_URL or 'https://open.bigmodel.cn/api/paas/v4' return LiteLLMProvider( provider="openai", model=_model, api_key=_api_key, base_url=_base_url, temperature=_temperature, max_tokens=_max_tokens ) else: logger.warning(f"Unsupported provider: {_provider}, falling back to default") return get_llm_provider() except ValueError as e: logger.error(f"Configuration error: {e}") raise except Exception as e: logger.error(f"Failed to create custom LLM provider: {e}", exc_info=True) # 降级到默认 provider return get_llm_provider() ================================================ FILE: backend/app/services/stock_data_service.py ================================================ """ 股票数据服务 - 使用 akshare 获取真实股票数据 """ import logging from datetime import datetime, timedelta from typing import List, Optional, Dict, Any from functools import lru_cache import asyncio logger = logging.getLogger(__name__) # 尝试导入 akshare try: import akshare as ak import pandas as pd AKSHARE_AVAILABLE = True except ImportError: AKSHARE_AVAILABLE = False logger.warning("akshare not installed, using mock data") class StockDataService: """股票数据服务 - 封装 akshare 接口""" # 缓存过期时间(秒) CACHE_TTL = 300 # 5分钟 CACHE_TTL_MINUTE = 60 # 分钟级数据缓存1分钟 # 股票代码前缀映射 MARKET_PREFIX = { "sh": "6", # 上海 60xxxx "sz": "0", # 深圳 00xxxx, 30xxxx "sz3": "3", # 创业板 30xxxx } # 周期映射 PERIOD_MAP = { "1m": "1", # 1分钟 "5m": "5", # 5分钟 "15m": "15", # 15分钟 "30m": "30", # 30分钟 "60m": "60", # 60分钟/1小时 "1h": "60", # 1小时(别名) "daily": "daily", # 日线 "1d": "daily", # 日线(别名) } def __init__(self): self._cache: Dict[str, tuple] = {} # {key: (data, timestamp)} def _normalize_code(self, stock_code: str) -> str: """ 标准化股票代码,返回纯数字代码 支持格式: SH600519, sh600519, 600519 """ code = stock_code.upper().strip() if code.startswith("SH") or code.startswith("SZ"): return code[2:] return code def _get_symbol(self, stock_code: str) -> str: """ 获取 akshare 使用的股票代码格式 akshare stock_zh_a_hist 需要纯数字代码 """ return self._normalize_code(stock_code) def _is_cache_valid(self, key: str, ttl: int = None) -> bool: """检查缓存是否有效""" if key not in self._cache: return False _, timestamp = self._cache[key] cache_ttl = ttl if ttl is not None else self.CACHE_TTL # 修复bug: 使用 total_seconds() 而不是 seconds # seconds 只返回秒数部分(0-86399),不包括天数 return (datetime.now() - timestamp).total_seconds() < cache_ttl def _get_cached(self, key: str, ttl: int = None) -> Optional[Any]: """获取缓存数据""" if self._is_cache_valid(key, ttl): return self._cache[key][0] # 清理过期缓存 if key in self._cache: del self._cache[key] return None def _set_cache(self, key: str, data: Any): """设置缓存""" self._cache[key] = (data, datetime.now()) def clear_cache(self, pattern: str = None): """ 清除缓存 Args: pattern: 可选的缓存键模式,如果提供则只清除匹配的缓存 """ if pattern: keys_to_delete = [k for k in self._cache.keys() if pattern in k] for key in keys_to_delete: del self._cache[key] logger.info(f"🧹 Cleared {len(keys_to_delete)} cache entries matching pattern: {pattern}") else: count = len(self._cache) self._cache.clear() logger.info(f"🧹 Cleared all {count} cache entries") async def get_kline_data( self, stock_code: str, period: str = "daily", # daily, 1m, 5m, 15m, 30m, 60m limit: int = 90, # 数据条数 adjust: str = "qfq" # qfq=前复权, hfq=后复权, ""=不复权 ) -> List[Dict[str, Any]]: """ 
获取K线数据(支持日线和分钟级数据) Args: stock_code: 股票代码 period: 周期 (daily, 1m, 5m, 15m, 30m, 60m) limit: 返回数据条数 adjust: 复权类型(仅日线有效) Returns: K线数据列表,每条包含: timestamp, open, high, low, close, volume, turnover """ # 标准化周期 period_key = self.PERIOD_MAP.get(period, period) cache_key = f"kline:{stock_code}:{period}:{limit}:{adjust}" # 根据周期使用不同的缓存TTL:日线5分钟,分钟级1分钟 cache_ttl = self.CACHE_TTL if period_key == "daily" else self.CACHE_TTL_MINUTE cached = self._get_cached(cache_key, ttl=cache_ttl) if cached: latest_date = cached[-1].get('date', 'unknown') if cached else 'empty' logger.info(f"🔵 Cache hit for {cache_key}, latest date: {latest_date}, count: {len(cached)}") return cached logger.info(f"🔴 Cache miss for {cache_key}, fetching fresh data...") if not AKSHARE_AVAILABLE: logger.warning("akshare not available, returning mock data") return self._generate_mock_kline(stock_code, limit) try: symbol = self._get_symbol(stock_code) loop = asyncio.get_event_loop() if period_key == "daily": # 日线数据 kline_data = await self._fetch_daily_kline(symbol, limit, adjust, loop) else: # 分钟级数据 kline_data = await self._fetch_minute_kline(symbol, period_key, limit, loop) if not kline_data: logger.warning(f"⚠️ No valid data after parsing for {stock_code} period={period}, using mock data") return self._generate_mock_kline(stock_code, limit) # 记录最新数据的日期和价格,便于调试 latest = kline_data[-1] logger.info(f"✅ Successfully fetched {len(kline_data)} kline records for {stock_code} period={period}, latest: {latest['date']}, close: {latest['close']}") self._set_cache(cache_key, kline_data) return kline_data except Exception as e: logger.error(f"❌ Failed to fetch kline data for {stock_code}: {type(e).__name__}: {e}", exc_info=True) # 只在某些特定错误时返回mock数据,其他错误应该抛出 if "NaTType" in str(e) or "timestamp" in str(e).lower(): logger.warning(f"Data parsing error, this should not happen after fix. 
Returning empty list.") return [] # 网络错误或API错误才返回mock数据 return self._generate_mock_kline(stock_code, limit) async def _fetch_daily_kline( self, symbol: str, limit: int, adjust: str, loop ) -> List[Dict[str, Any]]: """获取日线数据""" end_date = datetime.now() # 多获取一些天数,确保有足够数据(考虑周末和节假日,约1个交易日=1.5个自然日) # limit * 1.6 能确保获取到足够的交易日数据 start_date = end_date - timedelta(days=int(limit * 1.6)) logger.info(f"📊 Calling akshare API: symbol={symbol}, start={start_date.strftime('%Y%m%d')}, end={end_date.strftime('%Y%m%d')}, adjust={adjust}") df = await loop.run_in_executor( None, lambda: ak.stock_zh_a_hist( symbol=symbol, start_date=start_date.strftime("%Y%m%d"), end_date=end_date.strftime("%Y%m%d"), adjust=adjust ) ) logger.info(f"✅ Akshare returned {len(df) if df is not None and not df.empty else 0} rows") if df is None or df.empty: return [] # 清理数据:移除日期为NaT的行 df = df.dropna(subset=['日期']) # 只取最近 limit 条数据 df = df.tail(limit) # 转换为标准格式 kline_data = [] for _, row in df.iterrows(): try: # 处理日期 date_val = row['日期'] if pd.isna(date_val): logger.warning(f"Skipping row with NaT date") continue if isinstance(date_val, str): dt = datetime.strptime(date_val, "%Y-%m-%d") date_str = date_val else: dt = pd.to_datetime(date_val) if pd.isna(dt): logger.warning(f"Skipping row with invalid date") continue date_str = dt.strftime("%Y-%m-%d") timestamp = int(dt.timestamp() * 1000) kline_data.append({ "timestamp": timestamp, "date": date_str, "open": float(row['开盘']), "high": float(row['最高']), "low": float(row['最低']), "close": float(row['收盘']), "volume": int(row['成交量']), "turnover": float(row.get('成交额', 0)), "change_percent": float(row.get('涨跌幅', 0)), "change_amount": float(row.get('涨跌额', 0)), "amplitude": float(row.get('振幅', 0)), "turnover_rate": float(row.get('换手率', 0)), }) except Exception as e: logger.warning(f"Failed to parse row, skipping: {e}") continue # 记录数据范围 if kline_data: logger.info(f"✅ Parsed {len(kline_data)} valid records, date range: {kline_data[0]['date']} to {kline_data[-1]['date']}") return kline_data async def _fetch_minute_kline( self, symbol: str, period: str, # "1", "5", "15", "30", "60" limit: int, loop ) -> List[Dict[str, Any]]: """获取分钟级数据""" df = await loop.run_in_executor( None, lambda: ak.stock_zh_a_hist_min_em( symbol=symbol, period=period, adjust="" ) ) if df is None or df.empty: return [] # 清理数据:移除时间为NaT的行 df = df.dropna(subset=['时间']) # 只取最近 limit 条数据 df = df.tail(limit) # 转换为标准格式 kline_data = [] for _, row in df.iterrows(): try: # 处理时间 time_val = row['时间'] if pd.isna(time_val): logger.warning(f"Skipping row with NaT time") continue time_str = str(time_val) try: dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S") except: dt = pd.to_datetime(time_val) if pd.isna(dt): logger.warning(f"Skipping row with invalid time") continue time_str = dt.strftime("%Y-%m-%d %H:%M:%S") timestamp = int(dt.timestamp() * 1000) kline_data.append({ "timestamp": timestamp, "date": time_str, "open": float(row['开盘']), "high": float(row['最高']), "low": float(row['最低']), "close": float(row['收盘']), "volume": int(row['成交量']), "turnover": float(row.get('成交额', 0)), "change_percent": 0, # 分钟数据可能没有涨跌幅 "change_amount": 0, "amplitude": 0, "turnover_rate": 0, }) except Exception as e: logger.warning(f"Failed to parse minute row, skipping: {e}") continue # 记录数据范围 if kline_data: logger.info(f"✅ Parsed {len(kline_data)} valid minute records, time range: {kline_data[0]['date']} to {kline_data[-1]['date']}") return kline_data async def get_realtime_quote(self, stock_code: str) -> Optional[Dict[str, Any]]: """ 获取实时行情 Returns: 实时行情数据 """ 
cache_key = f"realtime:{stock_code}" cached = self._get_cached(cache_key) if cached: return cached if not AKSHARE_AVAILABLE: return None try: symbol = self._get_symbol(stock_code) loop = asyncio.get_event_loop() df = await loop.run_in_executor( None, lambda: ak.stock_zh_a_spot_em() ) if df is None or df.empty: return None # 根据股票代码筛选 row = df[df['代码'] == symbol] if row.empty: return None row = row.iloc[0] quote = { "code": symbol, "name": row.get('名称', ''), "price": float(row.get('最新价', 0)), "change_percent": float(row.get('涨跌幅', 0)), "change_amount": float(row.get('涨跌额', 0)), "volume": int(row.get('成交量', 0)), "turnover": float(row.get('成交额', 0)), "high": float(row.get('最高', 0)), "low": float(row.get('最低', 0)), "open": float(row.get('今开', 0)), "prev_close": float(row.get('昨收', 0)), } self._set_cache(cache_key, quote) return quote except Exception as e: logger.error(f"Failed to fetch realtime quote for {stock_code}: {e}") return None async def search_stocks( self, keyword: str, limit: int = 20 ) -> List[Dict[str, Any]]: """ 搜索股票(通过代码或名称模糊匹配) Args: keyword: 搜索关键词 limit: 返回数量限制 Returns: 股票列表 """ cache_key = f"search:{keyword}:{limit}" cached = self._get_cached(cache_key) if cached: return cached if not AKSHARE_AVAILABLE: return self._get_mock_stock_list(keyword, limit) try: loop = asyncio.get_event_loop() # 获取全部 A 股实时行情(包含代码和名称) df = await loop.run_in_executor( None, lambda: ak.stock_zh_a_spot_em() ) if df is None or df.empty: return self._get_mock_stock_list(keyword, limit) # 模糊匹配代码或名称 keyword_upper = keyword.upper() mask = ( df['代码'].str.contains(keyword_upper, na=False) | df['名称'].str.contains(keyword, na=False) ) matched = df[mask].head(limit) results = [] for _, row in matched.iterrows(): code = str(row['代码']) # 确定市场前缀 if code.startswith('6'): full_code = f"SH{code}" elif code.startswith('0') or code.startswith('3'): full_code = f"SZ{code}" else: full_code = code results.append({ "code": code, "name": str(row['名称']), "full_code": full_code, "price": float(row.get('最新价', 0)) if pd.notna(row.get('最新价')) else 0, "change_percent": float(row.get('涨跌幅', 0)) if pd.notna(row.get('涨跌幅')) else 0, }) self._set_cache(cache_key, results) return results except Exception as e: logger.error(f"Failed to search stocks: {e}") return self._get_mock_stock_list(keyword, limit) def _get_mock_stock_list(self, keyword: str, limit: int) -> List[Dict[str, Any]]: """返回模拟股票列表""" mock_stocks = [ {"code": "600519", "name": "贵州茅台", "full_code": "SH600519", "price": 1420.0, "change_percent": 0.5}, {"code": "000001", "name": "平安银行", "full_code": "SZ000001", "price": 12.0, "change_percent": -0.3}, {"code": "601318", "name": "中国平安", "full_code": "SH601318", "price": 45.0, "change_percent": 0.2}, {"code": "000858", "name": "五粮液", "full_code": "SZ000858", "price": 150.0, "change_percent": 1.1}, {"code": "002594", "name": "比亚迪", "full_code": "SZ002594", "price": 250.0, "change_percent": -0.8}, {"code": "600036", "name": "招商银行", "full_code": "SH600036", "price": 35.0, "change_percent": 0.1}, {"code": "601166", "name": "兴业银行", "full_code": "SH601166", "price": 18.0, "change_percent": 0.3}, {"code": "000333", "name": "美的集团", "full_code": "SZ000333", "price": 65.0, "change_percent": 0.6}, {"code": "002415", "name": "海康威视", "full_code": "SZ002415", "price": 32.0, "change_percent": -0.5}, {"code": "600276", "name": "恒瑞医药", "full_code": "SH600276", "price": 42.0, "change_percent": 0.4}, ] keyword_lower = keyword.lower() filtered = [ s for s in mock_stocks if keyword_lower in s["code"].lower() or keyword_lower in s["name"].lower() ] 
return filtered[:limit] async def get_stock_info(self, stock_code: str) -> Optional[Dict[str, Any]]: """ 获取股票基本信息 """ if not AKSHARE_AVAILABLE: return None try: symbol = self._get_symbol(stock_code) loop = asyncio.get_event_loop() df = await loop.run_in_executor( None, lambda: ak.stock_individual_info_em(symbol=symbol) ) if df is None or df.empty: return None # 转换为字典 info = {} for _, row in df.iterrows(): info[row['item']] = row['value'] return info except Exception as e: logger.error(f"Failed to fetch stock info for {stock_code}: {e}") return None def _generate_mock_kline(self, stock_code: str, days: int) -> List[Dict[str, Any]]: """ 生成模拟K线数据(当 akshare 不可用时使用) """ import random # 根据股票代码设定基准价格 base_prices = { "600519": 1500.0, # 贵州茅台 "000001": 12.0, # 平安银行 "601318": 45.0, # 中国平安 "000858": 150.0, # 五粮液 "002594": 250.0, # 比亚迪 } code = self._normalize_code(stock_code) base_price = base_prices.get(code, 50.0) current_price = base_price kline_data = [] for i in range(days): dt = datetime.now() - timedelta(days=days - i - 1) # 跳过周末 if dt.weekday() >= 5: continue timestamp = int(dt.timestamp() * 1000) date_str = dt.strftime("%Y-%m-%d") # 随机波动 change_percent = random.uniform(-3, 3) open_price = current_price close_price = current_price * (1 + change_percent / 100) high_price = max(open_price, close_price) * (1 + random.uniform(0, 1.5) / 100) low_price = min(open_price, close_price) * (1 - random.uniform(0, 1.5) / 100) volume = random.randint(50000, 500000) turnover = volume * close_price kline_data.append({ "timestamp": timestamp, "date": date_str, "open": round(open_price, 2), "high": round(high_price, 2), "low": round(low_price, 2), "close": round(close_price, 2), "volume": volume, "turnover": round(turnover, 2), "change_percent": round(change_percent, 2), "change_amount": round(close_price - open_price, 2), "amplitude": round((high_price - low_price) / open_price * 100, 2), "turnover_rate": round(random.uniform(0.5, 5), 2), }) current_price = close_price return kline_data[-days:] if len(kline_data) > days else kline_data async def get_financial_indicators(self, stock_code: str) -> Optional[Dict[str, Any]]: """ 获取股票财务指标(用于辩论分析) 包括:PE、PB、ROE、净利润增长率等 Args: stock_code: 股票代码 Returns: 财务指标字典 """ cache_key = f"financial:{stock_code}" cached = self._get_cached(cache_key, ttl=3600) # 财务数据缓存1小时 if cached: return cached if not AKSHARE_AVAILABLE: logger.warning("akshare not available, returning mock financial data") return self._get_mock_financial_indicators(stock_code) try: symbol = self._get_symbol(stock_code) loop = asyncio.get_event_loop() # 方法1:从实时行情获取基础估值数据 spot_df = await loop.run_in_executor( None, lambda: ak.stock_zh_a_spot_em() ) financial_data = {} if spot_df is not None and not spot_df.empty: row = spot_df[spot_df['代码'] == symbol] if not row.empty: row = row.iloc[0] financial_data.update({ "pe_ratio": self._safe_float(row.get('市盈率-动态')), "pb_ratio": self._safe_float(row.get('市净率')), "total_market_value": self._safe_float(row.get('总市值')), "circulating_market_value": self._safe_float(row.get('流通市值')), "turnover_rate": self._safe_float(row.get('换手率')), "volume_ratio": self._safe_float(row.get('量比')), "amplitude": self._safe_float(row.get('振幅')), "price_52w_high": self._safe_float(row.get('52周最高')), "price_52w_low": self._safe_float(row.get('52周最低')), }) # 方法2:尝试获取更详细的财务摘要 try: financial_abstract = await loop.run_in_executor( None, lambda: ak.stock_financial_abstract_ths(symbol=symbol) ) if financial_abstract is not None and not financial_abstract.empty: # 取最新一期数据 latest = 
financial_abstract.iloc[0] if len(financial_abstract) > 0 else None if latest is not None: financial_data.update({ "roe": self._safe_float(latest.get('净资产收益率')), "gross_profit_margin": self._safe_float(latest.get('毛利率')), "net_profit_margin": self._safe_float(latest.get('净利率')), "debt_ratio": self._safe_float(latest.get('资产负债率')), "revenue_yoy": self._safe_float(latest.get('营业总收入同比增长率')), "profit_yoy": self._safe_float(latest.get('净利润同比增长率')), }) except Exception as e: logger.debug(f"Failed to fetch financial abstract for {stock_code}: {e}") if financial_data: self._set_cache(cache_key, financial_data) return financial_data return self._get_mock_financial_indicators(stock_code) except Exception as e: logger.error(f"Failed to fetch financial indicators for {stock_code}: {e}") return self._get_mock_financial_indicators(stock_code) def _safe_float(self, value, default=None) -> Optional[float]: """安全转换为浮点数""" if value is None or (isinstance(value, float) and pd.isna(value)): return default try: return float(value) except (ValueError, TypeError): return default def _get_mock_financial_indicators(self, stock_code: str) -> Dict[str, Any]: """返回模拟财务指标""" return { "pe_ratio": 25.5, "pb_ratio": 3.2, "roe": 15.8, "total_market_value": 100000000000, # 1000亿 "circulating_market_value": 80000000000, "turnover_rate": 2.5, "gross_profit_margin": 45.2, "net_profit_margin": 22.1, "debt_ratio": 35.5, "revenue_yoy": 12.5, "profit_yoy": 18.3, } async def get_fund_flow(self, stock_code: str, days: int = 5) -> Optional[Dict[str, Any]]: """ 获取个股资金流向(用于辩论分析) 包括:主力资金净流入、散户资金流向等 Args: stock_code: 股票代码 days: 获取最近几天的数据 Returns: 资金流向数据 """ cache_key = f"fund_flow:{stock_code}:{days}" cached = self._get_cached(cache_key, ttl=300) # 资金流向缓存5分钟 if cached: return cached if not AKSHARE_AVAILABLE: logger.warning("akshare not available, returning mock fund flow data") return self._get_mock_fund_flow(stock_code) try: symbol = self._get_symbol(stock_code) loop = asyncio.get_event_loop() # 获取个股资金流向 df = await loop.run_in_executor( None, lambda: ak.stock_individual_fund_flow(stock=symbol, market="sh" if symbol.startswith("6") else "sz") ) if df is None or df.empty: return self._get_mock_fund_flow(stock_code) # 取最近几天的数据 df = df.head(days) # 汇总数据 total_main_net = 0 total_super_large_net = 0 total_large_net = 0 total_medium_net = 0 total_small_net = 0 daily_flows = [] for _, row in df.iterrows(): main_net = self._safe_float(row.get('主力净流入-净额'), 0) super_large_net = self._safe_float(row.get('超大单净流入-净额'), 0) large_net = self._safe_float(row.get('大单净流入-净额'), 0) medium_net = self._safe_float(row.get('中单净流入-净额'), 0) small_net = self._safe_float(row.get('小单净流入-净额'), 0) total_main_net += main_net total_super_large_net += super_large_net total_large_net += large_net total_medium_net += medium_net total_small_net += small_net daily_flows.append({ "date": str(row.get('日期', '')), "main_net": main_net, "super_large_net": super_large_net, "large_net": large_net, "medium_net": medium_net, "small_net": small_net, }) fund_flow_data = { "period_days": days, "total_main_net": total_main_net, "total_super_large_net": total_super_large_net, "total_large_net": total_large_net, "total_medium_net": total_medium_net, "total_small_net": total_small_net, "main_flow_trend": "流入" if total_main_net > 0 else "流出", "daily_flows": daily_flows, } self._set_cache(cache_key, fund_flow_data) return fund_flow_data except Exception as e: logger.error(f"Failed to fetch fund flow for {stock_code}: {e}") return self._get_mock_fund_flow(stock_code) def 
    _get_mock_fund_flow(self, stock_code: str) -> Dict[str, Any]:
        """Return mock fund-flow data."""
        return {
            "period_days": 5,
            "total_main_net": 50000000,  # 50 million
            "total_super_large_net": 30000000,
            "total_large_net": 20000000,
            "total_medium_net": -5000000,
            "total_small_net": -10000000,
            "main_flow_trend": "流入",
            "daily_flows": [],
        }

    async def get_debate_context(self, stock_code: str) -> Dict[str, Any]:
        """
        Build the combined context used for debate analysis.

        Merges financial indicators, fund flow and the realtime quote.

        Args:
            stock_code: stock code

        Returns:
            Combined context data
        """
        # Fetch the data sources concurrently
        realtime_task = self.get_realtime_quote(stock_code)
        financial_task = self.get_financial_indicators(stock_code)
        fund_flow_task = self.get_fund_flow(stock_code, days=5)

        realtime, financial, fund_flow = await asyncio.gather(
            realtime_task, financial_task, fund_flow_task,
            return_exceptions=True
        )

        # Handle exceptions returned by gather()
        if isinstance(realtime, Exception):
            logger.error(f"Failed to get realtime quote: {realtime}")
            realtime = None
        if isinstance(financial, Exception):
            logger.error(f"Failed to get financial indicators: {financial}")
            financial = None
        if isinstance(fund_flow, Exception):
            logger.error(f"Failed to get fund flow: {fund_flow}")
            fund_flow = None

        # Build a plain-text summary
        context_parts = []

        if realtime:
            context_parts.append(
                f"【实时行情】当前价: {realtime.get('price', 'N/A')}元, "
                f"涨跌幅: {realtime.get('change_percent', 'N/A')}%, "
                f"成交量: {realtime.get('volume', 'N/A')}"
            )

        if financial:
            pe = financial.get('pe_ratio')
            pb = financial.get('pb_ratio')
            roe = financial.get('roe')
            profit_yoy = financial.get('profit_yoy')
            context_parts.append(
                f"【估值指标】PE: {pe if pe else 'N/A'}, PB: {pb if pb else 'N/A'}, "
                f"ROE: {roe if roe else 'N/A'}%, 净利润同比: {profit_yoy if profit_yoy else 'N/A'}%"
            )

        if fund_flow:
            main_net = fund_flow.get('total_main_net', 0)
            main_net_str = f"{main_net/10000:.2f}万" if abs(main_net) < 100000000 else f"{main_net/100000000:.2f}亿"
            context_parts.append(
                f"【资金流向】近{fund_flow.get('period_days', 5)}日主力净{fund_flow.get('main_flow_trend', 'N/A')}: {main_net_str}"
            )

        return {
            "realtime": realtime,
            "financial": financial,
            "fund_flow": fund_flow,
            "summary": "\n".join(context_parts) if context_parts else "暂无额外数据",
        }


# Singleton instance
stock_data_service = StockDataService()
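For reference, a minimal sketch of how the singleton above could be consumed from an async caller such as the debate agents. The import path, stock code and printed fields are illustrative assumptions; actual values depend on akshare availability and the cache.

```python
# Hypothetical usage sketch of stock_data_service (not part of the repository).
import asyncio

from app.services.stock_data_service import stock_data_service  # assumed import path


async def main() -> None:
    # Aggregates realtime quote, financial indicators and 5-day fund flow in parallel
    context = await stock_data_service.get_debate_context("SH600519")
    print(context["summary"])    # human-readable text summary for the debate prompt
    print(context["financial"])  # dict like {"pe_ratio": ..., "roe": ...}, or mock values without akshare


if __name__ == "__main__":
    asyncio.run(main())
```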
================================================
FILE: backend/app/storage/__init__.py
================================================
"""
Storage module
"""
from .vector_storage import VectorStorage

__all__ = ["VectorStorage"]



================================================
FILE: backend/app/storage/vector_storage.py
================================================
"""
Vector storage wrapper - uses agenticx.storage.vectordb_storages.milvus.MilvusStorage directly.
Provides a thin compatibility interface and leverages the convenience methods of the base class.
"""
import logging
import asyncio
from typing import List, Dict, Any, Optional

from ..core.config import settings
from agenticx.storage.vectordb_storages.milvus import MilvusStorage
from agenticx.storage.vectordb_storages.base import VectorRecord, VectorDBQuery

logger = logging.getLogger(__name__)


class VectorStorage:
    """
    Milvus vector storage wrapper.

    Uses agenticx.storage.vectordb_storages.milvus.MilvusStorage directly and
    only performs the interface conversions needed for compatibility.
    """

    def __init__(
        self,
        host: str = None,
        port: int = None,
        collection_name: str = None,
        dim: int = None,
    ):
        """Initialize the vector storage."""
        self.host = host or settings.MILVUS_HOST
        self.port = port or settings.MILVUS_PORT
        self.collection_name = collection_name or settings.MILVUS_COLLECTION_NAME
        self.dim = dim or settings.MILVUS_DIM

        # Use the agenticx MilvusStorage directly
        self.milvus_storage = MilvusStorage(
            dimension=self.dim,
            host=self.host,
            port=self.port,
            collection_name=self.collection_name
        )

        logger.info(f"Initialized VectorStorage using MilvusStorage: {self.collection_name}, dim={self.dim}")

    def _call_add_async(self, records: List[VectorRecord], timeout: int = 15) -> None:
        """Helper: call the async add() method from a synchronous context."""
        try:
            loop = asyncio.get_running_loop()
            future = asyncio.run_coroutine_threadsafe(self.milvus_storage.add(records), loop)
            try:
                future.result(timeout=timeout)
            except Exception:
                logger.warning(f"Vector insert timeout ({timeout}s), but data may have been inserted")
        except RuntimeError:
            try:
                asyncio.run(asyncio.wait_for(self.milvus_storage.add(records), timeout=timeout))
            except asyncio.TimeoutError:
                logger.warning(f"Vector insert timeout ({timeout}s), but data may have been inserted")

    def connect(self):
        """Connect to Milvus (compatibility method)."""
        # MilvusStorage already connects during initialization
        pass

    def create_collection(self, drop_existing: bool = False):
        """Create the collection (compatibility method)."""
        # MilvusStorage already creates the collection during initialization
        if drop_existing:
            self.milvus_storage.clear()
            self.milvus_storage = MilvusStorage(
                dimension=self.dim,
                host=self.host,
                port=self.port,
                collection_name=self.collection_name
            )

    def load_collection(self):
        """Load the collection into memory (compatibility method)."""
        self.milvus_storage.load()

    def store_embedding(
        self,
        news_id: int,
        embedding: List[float],
        text: str
    ) -> int:
        """Store a single vector (compatibility interface)."""
        record = VectorRecord(
            id=str(news_id),
            vector=embedding,
            payload={"news_id": news_id, "text": text[:65535]}
        )
        self._call_add_async([record], timeout=15)
        return news_id

    def store_embeddings_batch(
        self,
        news_ids: List[int],
        embeddings: List[List[float]],
        texts: List[str]
    ) -> List[int]:
        """Store vectors in batch (compatibility interface)."""
        records = [
            VectorRecord(
                id=str(news_id),
                vector=embedding,
                payload={"news_id": news_id, "text": text[:65535]}
            )
            for news_id, embedding, text in zip(news_ids, embeddings, texts)
        ]
        self._call_add_async(records, timeout=30)
        return news_ids

    def search_similar(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filter_expr: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Search for similar vectors (compatibility interface)."""
        query = VectorDBQuery(query_vector=query_embedding, top_k=top_k)
        results = self.milvus_storage.query(query)

        # Format results
        formatted_results = []
        for result in results:
            payload = result.record.payload or {}
            news_id = payload.get("news_id")
            if news_id is None:
                try:
                    news_id = int(result.record.id)
                except (ValueError, TypeError):
                    continue

            # Simple filter support
            if filter_expr and "news_id" in filter_expr:
                import re
                match = re.search(r'news_id\s*==\s*(\d+)', filter_expr)
                if match and news_id != int(match.group(1)):
                    continue

            formatted_results.append({
                "id": result.record.id,
                "news_id": news_id,
                "text": payload.get("text", ""),
                "distance": result.similarity,
                "score": 1 / (1 + result.similarity) if result.similarity > 0 else 1.0,
            })

        return formatted_results

    def delete_by_news_id(self, news_id: int):
        """Delete the vectors of a given news item (compatibility interface)."""
        self.milvus_storage.delete([str(news_id)])

    def verify_insert(self, news_id: int, wait_for_flush: bool = True) -> bool:
        """Verify that data was inserted successfully (compatibility interface)."""
        if wait_for_flush:
            import time
            time.sleep(3)

        # Use the base class get_payloads_by_vector method
        zero_vector = [0.0] * self.dim
        payloads = self.milvus_storage.get_payloads_by_vector(zero_vector, top_k=1000)

        for payload in payloads:
            if payload and payload.get("news_id") == news_id:
                return True
        return False

    def get_stats(self) -> Dict[str, Any]:
        """Get collection statistics (compatibility interface).

        Note: if num_entities is 0, the real count is obtained via an actual
        query (num_entities can be inaccurate when flush fails).
        """
        status = self.milvus_storage.status()
        num_entities = status.vector_count

        # If num_entities is 0, try to obtain the real count via a query.
        # This works around inaccurate statistics when flush fails.
        if num_entities == 0:
            try:
                from agenticx.storage.vectordb_storages.base import VectorDBQuery
                # Query with a zero vector and a large top_k to get the actual count
                zero_vector = [0.0] * status.vector_dim
                query = VectorDBQuery(query_vector=zero_vector, top_k=10000)  # query at most 10000 records
                results = self.milvus_storage.query(query)
                if results:
                    num_entities = len(results)
                    # If 10000 records came back, there may be more; mark the value as approximate
                    if len(results) >= 10000:
                        num_entities = f"{len(results)}+ (近似值,实际可能更多)"
            except Exception as e:
                logger.debug(f"无法通过查询获取真实数量: {e}")
                # If the query fails, keep num_entities = 0

        return {
            "num_entities": num_entities,
            "collection_name": self.collection_name,
            "dim": status.vector_dim,
        }

    def disconnect(self):
        """Disconnect (compatibility method)."""
        self.milvus_storage.close()

    @property
    def collection(self):
        """Compatibility property: return the underlying Milvus collection object."""
        return self.milvus_storage.collection


# Global instance
_vector_storage: Optional[VectorStorage] = None


def get_vector_storage() -> VectorStorage:
    """Get the vector storage instance (singleton)."""
    global _vector_storage
    if _vector_storage is None:
        _vector_storage = VectorStorage()
    return _vector_storage
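A brief, hypothetical sketch of how the compatibility interface above can be used from synchronous code. The import path, news id, placeholder vectors and text are illustrative assumptions; the vector length must match the configured MILVUS_DIM.

```python
# Hypothetical usage sketch of the VectorStorage wrapper (illustrative values only).
from app.storage.vector_storage import get_vector_storage  # assumed import path

storage = get_vector_storage()   # singleton; MilvusStorage connects during initialization
dim = storage.dim                # dimension comes from settings.MILVUS_DIM

# Store one embedding for a news row (placeholder vector and text)
storage.store_embedding(news_id=42, embedding=[0.0] * dim, text="example news text")

# Query back the most similar items; each hit carries news_id, text and a similarity-derived score
for hit in storage.search_similar(query_embedding=[0.0] * dim, top_k=5):
    print(hit["news_id"], round(hit["score"], 3), hit["text"][:30])
```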
================================================
FILE: backend/app/tasks/__init__.py
================================================
"""
Celery task module
"""
from .crawl_tasks import realtime_crawl_task, cold_start_crawl_task

__all__ = [
    "realtime_crawl_task",
    "cold_start_crawl_task",
]



================================================
FILE: backend/app/tasks/crawl_tasks.py
================================================
"""
Celery crawl tasks - Phase 2: realtime monitoring upgrade + multi-source support
"""
import logging
import json
from datetime import datetime, timedelta
from typing import List, Dict, Any
from sqlalchemy import select, create_engine, text
from sqlalchemy.orm import Session
import asyncio

from ..core.celery_app import celery_app
from ..core.config import settings
from ..core.redis_client import redis_client
from ..models.crawl_task import CrawlTask, CrawlMode, TaskStatus
from ..models.news import News
from ..tools import (
    SinaCrawlerTool,
    TencentCrawlerTool,
    JwviewCrawlerTool,
    EeoCrawlerTool,
    CaijingCrawlerTool,
    Jingji21CrawlerTool,
    NbdCrawlerTool,
    YicaiCrawlerTool,
    Netease163CrawlerTool,
    EastmoneyCrawlerTool,
    bochaai_search,
    NewsItem,
)
from ..tools.crawler_enhanced import EnhancedCrawler, crawl_url

logger = logging.getLogger(__name__)


def clean_text_for_db(text: str) -> str:
    """
    Remove characters that are not safe to store in the database.

    PostgreSQL does not allow NUL characters (\x00) in text fields.

    Args:
        text: raw text

    Returns:
        Cleaned text
    """
    if text is None:
        return None
    if not isinstance(text, str):
        return text
    # Strip NUL characters
    return text.replace('\x00', '').replace('\0', '')


def get_crawler_tool(source: str):
    """
    Crawler factory function.

    Args:
        source: news source name

    Returns:
        The corresponding crawler instance
    """
    crawlers = {
        "sina": SinaCrawlerTool,
        "tencent": TencentCrawlerTool,
        "jwview": JwviewCrawlerTool,
        "eeo": EeoCrawlerTool,
        "caijing": CaijingCrawlerTool,
        "jingji21": Jingji21CrawlerTool,
        "nbd": NbdCrawlerTool,
        "yicai": YicaiCrawlerTool,
        "163": Netease163CrawlerTool,
        "eastmoney": EastmoneyCrawlerTool,
    }
    crawler_class = crawlers.get(source)
    if not crawler_class:
        raise ValueError(f"Unknown news source: {source}")
    return crawler_class()


def get_sync_db_session():
    """Get a synchronous database session (used inside Celery tasks)."""
    engine = create_engine(settings.SYNC_DATABASE_URL)
    return Session(engine)


@celery_app.task(bind=True, name="app.tasks.crawl_tasks.realtime_crawl_task")
def realtime_crawl_task(self, source: str = "sina", force_refresh: bool = False):
    """
    实时爬取任务 (Phase 2 升级版)

    核心改进:
    1. Redis 缓存检查(避免频繁爬取)
    2. 智能时间过滤(基于配置的 NEWS_RETENTION_HOURS)
    3.
只爬取最新一页 Args: source: 新闻源(sina, jrj等) force_refresh: 是否强制刷新(跳过缓存) """ db = get_sync_db_session() task_record = None cache_key = f"news:{source}:latest" cache_time_key = f"{cache_key}:timestamp" try: # ===== Phase 2.1: 检查 Redis 缓存 ===== if not force_refresh and redis_client.is_available(): cache_metadata = redis_client.get_cache_metadata(cache_key) if cache_metadata: age_seconds = cache_metadata['age_seconds'] # 根据不同源获取对应的爬取间隔 interval_map = { "sina": settings.CRAWL_INTERVAL_SINA, "tencent": settings.CRAWL_INTERVAL_TENCENT, "jwview": settings.CRAWL_INTERVAL_JWVIEW, "eeo": settings.CRAWL_INTERVAL_EEO, "caijing": settings.CRAWL_INTERVAL_CAIJING, "jingji21": settings.CRAWL_INTERVAL_JINGJI21, "nbd": 60, # 每日经济新闻 "yicai": 60, # 第一财经 "163": 60, # 网易财经 "eastmoney": 60, # 东方财富 } interval = interval_map.get(source, 60) # 默认60秒 # 如果缓存时间 < 爬取间隔,使用缓存 if age_seconds < interval: logger.info( f"[{source}] 使用缓存数据 (age: {age_seconds:.0f}s < {interval}s)" ) return { "status": "cached", "source": source, "cache_age": age_seconds, "message": f"缓存数据仍然有效,距上次爬取 {age_seconds:.0f} 秒" } # ===== 1. 创建任务记录 ===== task_record = CrawlTask( celery_task_id=self.request.id, mode=CrawlMode.REALTIME, status=TaskStatus.RUNNING, source=source, config={ "page_limit": 1, "retention_hours": settings.NEWS_RETENTION_HOURS, "force_refresh": force_refresh }, started_at=datetime.utcnow(), ) db.add(task_record) db.commit() db.refresh(task_record) logger.info(f"[Task {task_record.id}] 🚀 开始实时爬取: {source}") # ===== 2. 创建爬虫(使用工厂函数) ===== try: crawler = get_crawler_tool(source) except ValueError as e: logger.error(f"[Task {task_record.id}] ❌ {e}") raise # ===== 3. 执行爬取(只爬第一页) ===== start_time = datetime.utcnow() news_list = crawler.crawl(start_page=1, end_page=1) logger.info(f"[Task {task_record.id}] 📰 爬取到 {len(news_list)} 条新闻") # ===== Phase 2.2: 智能时间过滤 ===== cutoff_time = datetime.utcnow() - timedelta(hours=settings.NEWS_RETENTION_HOURS) recent_news = [ news for news in news_list if news.publish_time and news.publish_time > cutoff_time ] if news_list else [] logger.info( f"[Task {task_record.id}] ⏱️ 过滤后剩余 {len(recent_news)} 条新闻 " f"(保留 {settings.NEWS_RETENTION_HOURS} 小时内)" ) # ===== 4. 去重并保存 ===== saved_count = 0 duplicate_count = 0 for news_item in recent_news: # 检查URL是否已存在 existing = db.execute( select(News).where(News.url == news_item.url) ).scalar_one_or_none() if existing: duplicate_count += 1 logger.debug(f"[Task {task_record.id}] ⏭️ 跳过重复新闻: {news_item.title[:30]}...") continue # 创建新记录(清理 NUL 字符,PostgreSQL 不允许存储) news = News( title=clean_text_for_db(news_item.title), content=clean_text_for_db(news_item.content), raw_html=clean_text_for_db(news_item.raw_html), # 保存原始 HTML url=clean_text_for_db(news_item.url), source=clean_text_for_db(news_item.source), publish_time=news_item.publish_time, author=clean_text_for_db(news_item.author), keywords=news_item.keywords, stock_codes=news_item.stock_codes, ) db.add(news) saved_count += 1 db.commit() logger.info( f"[Task {task_record.id}] 💾 保存 {saved_count} 条新新闻 " f"(重复: {duplicate_count})" ) # ===== Phase 2.3: 更新 Redis 缓存 ===== if redis_client.is_available() and recent_news: # 将新闻列表序列化后存入缓存 cache_data = [ { "title": n.title, "url": n.url, "publish_time": n.publish_time.isoformat() if n.publish_time else None, "source": n.source, } for n in recent_news ] success = redis_client.set_with_metadata( cache_key, cache_data, ttl=settings.CACHE_TTL ) if success: logger.info(f"[Task {task_record.id}] 💾 Redis 缓存已更新 (TTL: {settings.CACHE_TTL}s)") # ===== 5. 
更新任务状态 ===== end_time = datetime.utcnow() execution_time = (end_time - start_time).total_seconds() task_record.status = TaskStatus.COMPLETED task_record.completed_at = end_time task_record.execution_time = execution_time task_record.crawled_count = len(recent_news) task_record.saved_count = saved_count task_record.result = { "total_crawled": len(news_list), "filtered": len(recent_news), "saved": saved_count, "duplicates": duplicate_count, "retention_hours": settings.NEWS_RETENTION_HOURS, } db.commit() logger.info( f"[Task {task_record.id}] ✅ 完成! " f"爬取: {len(news_list)} → 过滤: {len(recent_news)} → 保存: {saved_count}, " f"耗时: {execution_time:.2f}s" ) return { "task_id": task_record.id, "status": "completed", "source": source, "crawled": len(news_list), "filtered": len(recent_news), "saved": saved_count, "duplicates": duplicate_count, "execution_time": execution_time, "timestamp": datetime.utcnow().isoformat(), } except Exception as e: logger.error(f"[Task {task_record.id if task_record else 'unknown'}] 爬取失败: {e}", exc_info=True) if task_record: task_record.status = TaskStatus.FAILED task_record.completed_at = datetime.utcnow() task_record.error_message = str(e)[:1000] db.commit() # 重新抛出异常,让 Celery 记录 raise finally: db.close() @celery_app.task(bind=True, name="app.tasks.crawl_tasks.cold_start_crawl_task") def cold_start_crawl_task( self, source: str = "sina", start_page: int = 1, end_page: int = 50, ): """ 冷启动批量爬取任务 Args: source: 新闻源 start_page: 起始页 end_page: 结束页 """ db = get_sync_db_session() task_record = None try: # 1. 创建任务记录 task_record = CrawlTask( celery_task_id=self.request.id, mode=CrawlMode.COLD_START, status=TaskStatus.RUNNING, source=source, config={ "start_page": start_page, "end_page": end_page, }, total_pages=end_page - start_page + 1, started_at=datetime.utcnow(), ) db.add(task_record) db.commit() db.refresh(task_record) logger.info(f"[Task {task_record.id}] 开始冷启动爬取: {source}, 页码 {start_page}-{end_page}") # 2. 创建爬虫 if source == "sina": crawler = SinaCrawlerTool() else: raise ValueError(f"不支持的新闻源: {source}") # 3. 分页爬取 start_time = datetime.utcnow() total_crawled = 0 total_saved = 0 for page in range(start_page, end_page + 1): try: # 更新进度 task_record.current_page = page task_record.progress = { "current_page": page, "total_pages": task_record.total_pages, "percentage": round((page - start_page + 1) / task_record.total_pages * 100, 2), } db.commit() # 爬取单页 news_list = crawler.crawl(start_page=page, end_page=page) total_crawled += len(news_list) # 保存新闻 page_saved = 0 for news_item in news_list: existing = db.execute( select(News).where(News.url == news_item.url) ).scalar_one_or_none() if not existing: # 清理 NUL 字符,PostgreSQL 不允许存储 news = News( title=clean_text_for_db(news_item.title), content=clean_text_for_db(news_item.content), raw_html=clean_text_for_db(news_item.raw_html), # 保存原始 HTML url=clean_text_for_db(news_item.url), source=clean_text_for_db(news_item.source), publish_time=news_item.publish_time, author=clean_text_for_db(news_item.author), keywords=news_item.keywords, stock_codes=news_item.stock_codes, ) db.add(news) page_saved += 1 db.commit() total_saved += page_saved logger.info( f"[Task {task_record.id}] 页 {page}/{end_page}: " f"爬取 {len(news_list)} 条, 保存 {page_saved} 条" ) except Exception as e: logger.error(f"[Task {task_record.id}] 页 {page} 爬取失败: {e}") continue # 4. 
更新任务状态 end_time = datetime.utcnow() execution_time = (end_time - start_time).total_seconds() task_record.status = TaskStatus.COMPLETED task_record.completed_at = end_time task_record.execution_time = execution_time task_record.crawled_count = total_crawled task_record.saved_count = total_saved task_record.result = { "pages_crawled": end_page - start_page + 1, "total_crawled": total_crawled, "total_saved": total_saved, "duplicates": total_crawled - total_saved, } db.commit() logger.info( f"[Task {task_record.id}] 冷启动完成! " f"页数: {end_page - start_page + 1}, 爬取: {total_crawled}, 保存: {total_saved}, " f"耗时: {execution_time:.2f}s" ) return { "task_id": task_record.id, "status": "completed", "crawled": total_crawled, "saved": total_saved, "execution_time": execution_time, } except Exception as e: logger.error(f"[Task {task_record.id if task_record else 'unknown'}] 冷启动失败: {e}", exc_info=True) if task_record: task_record.status = TaskStatus.FAILED task_record.completed_at = datetime.utcnow() task_record.error_message = str(e)[:1000] db.commit() raise finally: db.close() @celery_app.task(bind=True, name="app.tasks.crawl_tasks.targeted_stock_crawl_task") def targeted_stock_crawl_task( self, stock_code: str, stock_name: str, days: int = 30, task_record_id: int = None ): """ 定向爬取某只股票的相关新闻(精简版 - 只使用 BochaAI) 数据来源:BochaAI 搜索引擎 API 图谱构建逻辑: - 有历史新闻数据 → 先构建/使用图谱 → 基于图谱扩展关键词搜索 - 无历史新闻数据 → 先用 BochaAI 爬取 → 爬取完成后异步构建图谱 Args: stock_code: 股票代码(如 SH600519) stock_name: 股票名称(如 贵州茅台) days: 搜索时间范围(天),默认30天 task_record_id: 数据库中的任务记录ID(如果已创建) """ db = get_sync_db_session() task_record = None try: # 标准化股票代码 code = stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): pure_code = code[2:] else: pure_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" # 1. 获取或创建任务记录 if task_record_id: task_record = db.query(CrawlTask).filter(CrawlTask.id == task_record_id).first() if task_record: task_record.status = TaskStatus.RUNNING task_record.started_at = datetime.utcnow() db.commit() db.refresh(task_record) else: logger.warning(f"Task record {task_record_id} not found, creating new one") task_record_id = None if not task_record: task_record = CrawlTask( celery_task_id=self.request.id, mode=CrawlMode.TARGETED, status=TaskStatus.RUNNING, source="targeted", config={ "stock_code": code, "stock_name": stock_name, "days": days, }, started_at=datetime.utcnow(), ) db.add(task_record) db.commit() db.refresh(task_record) logger.info(f"[Task {task_record.id}] 🎯 开始定向爬取: {stock_name}({code}), 时间范围: {days}天") start_time = datetime.utcnow() all_news = [] search_results = [] # ======================================== # 【核心逻辑】先用 akshare 获取股票基础信息,构建简单图谱 # ======================================== task_record.progress = {"current": 5, "total": 100, "message": "获取股票基础信息..."} db.commit() from ..knowledge.knowledge_extractor import AkshareKnowledgeExtractor # 1. 从 akshare 获取公司基础信息 logger.info(f"[Task {task_record.id}] 🔍 从 akshare 获取 {stock_name}({pure_code}) 基础信息...") akshare_info = None try: akshare_info = AkshareKnowledgeExtractor.extract_company_info(pure_code) if akshare_info: logger.info(f"[Task {task_record.id}] ✅ akshare 返回: 行业={akshare_info.get('industry')}, 主营={akshare_info.get('main_business', '')[:50]}...") else: logger.warning(f"[Task {task_record.id}] ⚠️ akshare 未返回数据,将使用股票名称生成关键词") except Exception as e: logger.warning(f"[Task {task_record.id}] ⚠️ akshare 查询失败: {e},将使用股票名称生成关键词") # 2. 
构建简单图谱并生成搜索关键词 task_record.progress = {"current": 10, "total": 100, "message": "构建知识图谱..."} db.commit() simple_graph = AkshareKnowledgeExtractor.build_simple_graph_from_info( stock_code=code, stock_name=stock_name, akshare_info=akshare_info ) # 获取分层关键词 core_keywords = simple_graph.get("core_keywords", [stock_name]) extension_keywords = simple_graph.get("extension_keywords", []) logger.info( f"[Task {task_record.id}] 📋 关键词分层: " f"核心={len(core_keywords)}个{core_keywords[:4]}, " f"扩展={len(extension_keywords)}个{extension_keywords[:4]}" ) logger.info(f"[Task {task_record.id}] 🔑 完整核心关键词列表: {core_keywords}") logger.info(f"[Task {task_record.id}] 🔑 完整扩展关键词列表: {extension_keywords}") # ======================================== # 【搜索阶段】使用组合关键词调用 BochaAI 搜索 # ======================================== task_record.progress = {"current": 20, "total": 100, "message": "BochaAI 组合搜索中..."} db.commit() if not bochaai_search.is_available(): logger.error(f"[Task {task_record.id}] ❌ BochaAI API Key 未配置,无法执行搜索") raise ValueError("BochaAI API Key 未配置") # ======================================== # 【组合搜索策略】 # 1. 必须搜索:核心关键词(公司名、代码) # 2. 可选组合:核心词 + 扩展词(行业、业务、人名) # ======================================== all_search_results = [] search_queries = [] # 策略1:核心关键词单独搜索(取前3个最重要的) for core_kw in core_keywords[:3]: # 跳过纯数字代码(单独搜会很泛) if not (core_kw.isdigit() or core_kw.startswith("SH") or core_kw.startswith("SZ")): search_queries.append(core_kw) # 策略2:核心词 + 扩展词组合搜索(最多3个组合) if extension_keywords: # 取最主要的核心词(通常是股票简称) main_core = core_keywords[0] if core_keywords else stock_name for ext_kw in extension_keywords[:3]: # 组合搜索:如 "*ST国华 软件开发" combined_query = f"{main_core} {ext_kw}" search_queries.append(combined_query) # 限制总查询数(避免过多请求) search_queries = search_queries[:5] logger.info(f"[Task {task_record.id}] 🚀 生成 {len(search_queries)} 个搜索查询:") for i, q in enumerate(search_queries): logger.info(f" [{i+1}] {q}") # 执行搜索 for query in search_queries: try: logger.info(f"[Task {task_record.id}] 🔍 搜索: '{query}'") kw_results = bochaai_search.search_stock_news( stock_name=query, # 使用组合查询 stock_code=pure_code, days=days, count=50, # 每个查询最多 50 条 max_age_days=365 ) logger.info(f"[Task {task_record.id}] 📰 查询 '{query}' 搜索到 {len(kw_results)} 条结果") all_search_results.extend(kw_results) except Exception as e: logger.warning(f"[Task {task_record.id}] ⚠️ 查询 '{query}' 搜索失败: {e}") # 去重(按 URL) seen_urls = set() search_results = [] for r in all_search_results: if r.url not in seen_urls: seen_urls.add(r.url) search_results.append(r) logger.info(f"[Task {task_record.id}] 📊 合并 {len(all_search_results)} 条,去重后 {len(search_results)} 条") # ======================================== # 【处理阶段】转换搜索结果为 NewsItem # ======================================== task_record.progress = {"current": 50, "total": 100, "message": "处理搜索结果..."} db.commit() bochaai_matched = 0 bochaai_filtered = 0 # 检查是否应该启用宽松过滤模式 # 如果核心关键词太少(<= 2个),或者搜索结果很少(<10条),使用宽松过滤 use_relaxed_filter = len(core_keywords) <= 2 or len(search_results) < 10 if use_relaxed_filter: logger.info(f"[Task {task_record.id}] 🔓 启用宽松过滤模式(核心词={len(core_keywords)}个, 结果={len(search_results)}条)") # 打印 BochaAI 返回的前 10 条数据用于调试 logger.info(f"[Task {task_record.id}] 📋 BochaAI 返回数据预览 (前10条):") for i, r in enumerate(search_results[:10]): logger.info(f" [{i+1}] 标题: {r.title[:60]}...") logger.info(f" 来源: {r.site_name}, 日期: {r.date_published}") logger.info(f" URL: {r.url[:80]}...") for idx, result in enumerate(search_results): # 解析发布时间 publish_time = None if result.date_published: try: publish_time = datetime.fromisoformat( 
result.date_published.replace('Z', '+00:00') ) except (ValueError, AttributeError): pass # 【注意】不再二次爬取完整内容,直接使用摘要(提升速度) full_content = result.snippet # 相关性过滤:必须包含至少一个核心关键词 text_to_check = result.title + " " + result.snippet text_to_check_lower = text_to_check.lower() # 检查是否匹配任何核心关键词 is_match = False matched_keyword = None for kw in core_keywords: if not kw or len(kw) < 2: continue kw_lower = kw.lower() # 宽松匹配策略: # 1. 完整匹配(大小写不敏感) if kw in text_to_check or kw_lower in text_to_check_lower: is_match = True matched_keyword = kw break # 2. 去除特殊字符后匹配(处理 *ST 等情况) import re kw_clean = re.sub(r'[*\s]', '', kw) if len(kw_clean) >= 2 and kw_clean.lower() in text_to_check_lower: is_match = True matched_keyword = f"{kw} (cleaned: {kw_clean})" break if not is_match: # 宽松模式下,如果标题包含股票代码数字,也认为相关 if use_relaxed_filter and pure_code in text_to_check: is_match = True matched_keyword = f"{pure_code} (relaxed mode)" logger.debug(f"[Task {task_record.id}] 🔓 宽松模式匹配: {result.title[:40]}... (包含代码)") else: bochaai_filtered += 1 # 打印前 5 条被过滤的原因 if bochaai_filtered <= 5: logger.info(f"[Task {task_record.id}] ❌ 过滤[{idx+1}]: 不包含核心关键词") logger.info(f" 标题: {result.title[:80]}") logger.info(f" 摘要: {result.snippet[:100]}...") logger.info(f" 核心词: {core_keywords}") continue # 如果宽松模式跳过了上面的 continue,需要确保 is_match 为 True if not is_match: continue logger.debug(f"[Task {task_record.id}] ✅ 匹配核心词 '{matched_keyword}': {result.title[:40]}...") bochaai_matched += 1 # 尝试爬取页面获取完整 HTML(只对前 15 条匹配结果爬取,避免任务太慢) raw_html = None crawled_content = None if bochaai_matched <= 15: try: from ..tools.interactive_crawler import InteractiveCrawler page_crawler = InteractiveCrawler(timeout=10) page_data = page_crawler.crawl_page(result.url) if page_data: raw_html = page_data.get('html') crawled_content = page_data.get('content') or page_data.get('text') logger.debug(f"[Task {task_record.id}] 📄 爬取成功: {result.url[:50]}... 
| HTML {len(raw_html) if raw_html else 0}字符") except Exception as e: logger.debug(f"[Task {task_record.id}] ⚠️ 爬取页面失败 {result.url[:50]}...: {e}") # 优先使用爬取的完整内容 final_content = crawled_content if crawled_content and len(crawled_content) > len(full_content) else full_content news_item = NewsItem( title=result.title, content=final_content, url=result.url, source=result.site_name or "web_search", publish_time=publish_time, stock_codes=[pure_code, code], raw_html=raw_html, ) all_news.append(news_item) # 每处理 20 条更新一次进度 if (idx + 1) % 20 == 0: progress_pct = 50 + int((idx + 1) / len(search_results) * 30) task_record.progress = {"current": progress_pct, "total": 100, "message": f"处理中 {idx+1}/{len(search_results)}..."} db.commit() logger.info(f"[Task {task_record.id}] 🔍 搜索到 {len(search_results)} 条,匹配 {bochaai_matched} 条,过滤 {bochaai_filtered} 条") # ======================================== # 【交互式爬虫补充】如果相关性匹配结果太少,使用交互式爬虫补充 # ======================================== if bochaai_matched < 5: # 匹配结果太少时启动交互式爬虫 logger.info(f"[Task {task_record.id}] 🌐 相关结果较少({bochaai_matched}条),启用交互式爬虫补充...") try: from ..tools.interactive_crawler import create_interactive_crawler # 使用核心关键词进行搜索 # 取最主要的核心词(通常是股票简称) interactive_query = core_keywords[0] if core_keywords else stock_name logger.info(f"[Task {task_record.id}] 🔍 使用交互式爬虫搜索: '{interactive_query}'") crawler = create_interactive_crawler(headless=True) # 使用百度资讯搜索(专门获取新闻,比 Bing 更稳定) interactive_results = crawler.interactive_search( interactive_query, engines=["baidu_news", "sogou"], # 百度资讯 + 搜狗 num_results=15, search_type="news" # 新闻搜索 ) logger.info(f"[Task {task_record.id}] ✅ 交互式爬虫返回 {len(interactive_results)} 条结果") # 现在使用 news.baidu.com 入口,返回的是真实的第三方链接 # 可以安全爬取这些页面获取完整内容(除了需要 JS 渲染的网站) # 需要 JS 渲染的网站列表(无法用 requests 爬取) JS_RENDERED_SITES = [ 'baijiahao.baidu.com', # 百家号需要 JS 渲染 'mbd.baidu.com', # 百度移动版 'xueqiu.com', # 雪球 'mp.weixin.qq.com', # 微信公众号 ] for result in interactive_results[:10]: # 最多取 10 条 url = result.get('url', '') title = result.get('title', '') snippet = result.get('snippet', '') # 跳过无效结果 if not url or not title: continue # 跳过已存在的 URL if url in {item.url for item in all_news}: continue # 跳过百度跳转链接 if 'baidu.com/link?' 
in url: logger.debug(f"跳过百度跳转链接: {url}") continue # 检查是否是需要 JS 渲染的网站 needs_js_render = any(site in url for site in JS_RENDERED_SITES) page_content = "" raw_html = None if needs_js_render: # JS 渲染网站:直接使用搜索结果的摘要 logger.debug(f" ⚠️ JS渲染网站,使用搜索摘要: {url[:50]}...") page_content = snippet if snippet else title else: # 普通网站:尝试爬取页面获取完整内容 try: page_data = crawler.crawl_page(url) if page_data: page_content = page_data.get('text', '') or page_data.get('content', '') raw_html = page_data.get('html', '') # 如果爬取的标题更完整,使用爬取的标题 if page_data.get('title') and len(page_data.get('title', '')) > len(title): title = page_data.get('title', title) logger.debug(f" ✅ 成功爬取页面: {title[:30]}...") except Exception as e: logger.debug(f" ⚠️ 爬取页面失败 {url}: {e}") # 如果爬取失败,使用搜索结果的摘要 if not page_content: page_content = snippet if snippet else title news_item = NewsItem( title=title, content=page_content, url=url, source=result.get('news_source') or result.get('source', 'baidu_news'), publish_time=None, # 交互爬虫没有发布时间 stock_codes=[pure_code, code], raw_html=raw_html, # JS 渲染网站不保存乱码 HTML ) all_news.append(news_item) bochaai_matched += 1 logger.info(f"[Task {task_record.id}] 📊 交互式爬虫补充后总计: {bochaai_matched} 条匹配结果") except ImportError: logger.warning(f"[Task {task_record.id}] ⚠️ 交互式爬虫模块不可用,跳过补充搜索") except Exception as e: logger.error(f"[Task {task_record.id}] ❌ 交互式爬虫补充失败: {e}", exc_info=True) # ======================================== # 【保存阶段】去重并保存新闻 # ======================================== task_record.progress = {"current": 80, "total": 100, "message": "保存新闻..."} db.commit() saved_count = 0 duplicate_count = 0 logger.info(f"[Task {task_record.id}] 💾 开始保存 {len(all_news)} 条新闻...") for news_item in all_news: # 检查URL是否已存在 existing = db.execute( select(News).where(News.url == news_item.url) ).scalar_one_or_none() if existing: duplicate_count += 1 # 如果已存在但没有关联这个股票,更新关联 if existing.stock_codes is None: existing.stock_codes = [] if pure_code not in existing.stock_codes: existing.stock_codes = existing.stock_codes + [pure_code] db.commit() continue # 创建新记录(清理 NUL 字符,PostgreSQL 不允许存储) news = News( title=clean_text_for_db(news_item.title), content=clean_text_for_db(news_item.content), raw_html=clean_text_for_db(news_item.raw_html), # 保存原始 HTML url=clean_text_for_db(news_item.url), source=clean_text_for_db(news_item.source), publish_time=news_item.publish_time, author=clean_text_for_db(news_item.author), keywords=news_item.keywords, stock_codes=news_item.stock_codes or [pure_code, code], ) db.add(news) saved_count += 1 db.commit() logger.info( f"[Task {task_record.id}] 💾 保存 {saved_count} 条新闻 " f"(重复: {duplicate_count})" ) # ======================================== # 【图谱更新阶段】异步构建完整图谱(基于 Neo4j) # ======================================== task_record.progress = {"current": 90, "total": 100, "message": "触发异步图谱构建..."} db.commit() if saved_count > 0: # 有新闻保存成功后,触发异步图谱构建任务 logger.info(f"[Task {task_record.id}] 🧠 触发异步图谱构建任务...") try: build_knowledge_graph_task.delay(code, stock_name) logger.info(f"[Task {task_record.id}] ✅ 异步图谱构建任务已触发") except Exception as e: logger.error(f"[Task {task_record.id}] ❌ 触发异步图谱构建失败: {e}") # ======================================== # 【完成阶段】更新任务状态 # ======================================== end_time = datetime.utcnow() execution_time = (end_time - start_time).total_seconds() task_record.status = TaskStatus.COMPLETED task_record.completed_at = end_time task_record.execution_time = execution_time task_record.crawled_count = len(all_news) task_record.saved_count = saved_count task_record.result = { "stock_code": code, 
"stock_name": stock_name, "total_found": len(all_news), "saved": saved_count, "duplicates": duplicate_count, "akshare_info": bool(akshare_info), # 是否获取到 akshare 数据 "core_keywords": core_keywords[:5], # 核心关键词 "search_queries": search_queries, # 实际搜索的查询 "sources": { "bochaai": len(search_results), } } task_record.progress = { "current": 100, "total": 100, "message": f"完成!新增 {saved_count} 条新闻" } db.commit() logger.info( f"[Task {task_record.id}] ✅ 定向爬取完成! " f"股票: {stock_name}({code}), 找到: {len(all_news)}, 保存: {saved_count}, " f"耗时: {execution_time:.2f}s" ) return { "task_id": task_record.id, "status": "completed", "stock_code": code, "stock_name": stock_name, "crawled": len(all_news), "saved": saved_count, "duplicates": duplicate_count, "execution_time": execution_time, "timestamp": datetime.utcnow().isoformat(), } except Exception as e: logger.error(f"[Task {task_record.id if task_record else 'unknown'}] 定向爬取失败: {e}", exc_info=True) if task_record: task_record.status = TaskStatus.FAILED task_record.completed_at = datetime.utcnow() task_record.error_message = str(e)[:1000] task_record.progress = { "current": 0, "total": 100, "message": f"失败: {str(e)[:100]}" } db.commit() raise finally: db.close() @celery_app.task(bind=True, name="app.tasks.crawl_tasks.build_knowledge_graph_task") def build_knowledge_graph_task(self, stock_code: str, stock_name: str): """ 异步构建知识图谱任务 在无历史新闻数据的股票首次爬取完成后触发。 从数据库中的新闻数据 + akshare 基础信息构建知识图谱。 Args: stock_code: 股票代码(如 SH600519) stock_name: 股票名称(如 贵州茅台) """ db = get_sync_db_session() try: code = stock_code.upper() if code.startswith("SH") or code.startswith("SZ"): pure_code = code[2:] else: pure_code = code code = f"SH{code}" if code.startswith("6") else f"SZ{code}" logger.info(f"[GraphTask] 🏗️ 开始异步构建知识图谱: {stock_name}({code})") from ..knowledge.graph_service import get_graph_service from ..knowledge.knowledge_extractor import ( create_knowledge_extractor, AkshareKnowledgeExtractor ) graph_service = get_graph_service() # 1. 检查图谱是否已存在(避免重复构建) existing_graph = graph_service.get_company_graph(code) if existing_graph: logger.info(f"[GraphTask] ✅ 图谱已存在,跳过构建") return {"status": "skipped", "reason": "graph_exists"} # 2. 从 akshare 获取基础公司信息 akshare_info = AkshareKnowledgeExtractor.extract_company_info(code) if akshare_info: extractor = create_knowledge_extractor() base_graph = asyncio.run( extractor.extract_from_akshare(code, stock_name, akshare_info) ) graph_service.build_company_graph(base_graph) logger.info(f"[GraphTask] ✅ 基础图谱构建完成") else: logger.warning(f"[GraphTask] ⚠️ akshare 未返回数据") # 3. 
从数据库新闻中提取信息更新图谱 recent_news = db.execute( text(""" SELECT title, content FROM news WHERE stock_codes @> ARRAY[:code]::varchar[] ORDER BY publish_time DESC LIMIT 50 """).bindparams(code=pure_code) ).fetchall() if recent_news: news_data = [{"title": n[0], "content": n[1]} for n in recent_news] extractor = create_knowledge_extractor() extracted_info = asyncio.run( extractor.extract_from_news(code, stock_name, news_data) ) if any(extracted_info.values()): graph_service.update_from_news(code, "", extracted_info) logger.info(f"[GraphTask] ✅ 从新闻更新图谱完成") logger.info(f"[GraphTask] ✅ 知识图谱构建完成: {stock_name}({code})") return { "status": "completed", "stock_code": code, "stock_name": stock_name, "news_count": len(recent_news) if recent_news else 0, } except Exception as e: logger.error(f"[GraphTask] ❌ 知识图谱构建失败: {e}", exc_info=True) return {"status": "failed", "error": str(e)} finally: db.close() ================================================ FILE: backend/app/tools/__init__.py ================================================ """ 工具模块 """ from .crawler_base import BaseCrawler, NewsItem from .sina_crawler import SinaCrawlerTool, create_sina_crawler from .tencent_crawler import TencentCrawlerTool from .jwview_crawler import JwviewCrawlerTool from .eeo_crawler import EeoCrawlerTool from .caijing_crawler import CaijingCrawlerTool from .jingji21_crawler import Jingji21CrawlerTool from .nbd_crawler import NbdCrawlerTool from .yicai_crawler import YicaiCrawlerTool from .netease163_crawler import Netease163CrawlerTool from .eastmoney_crawler import EastmoneyCrawlerTool from .text_cleaner import TextCleanerTool, create_text_cleaner from .bochaai_search import BochaAISearchTool, bochaai_search, SearchResult __all__ = [ "BaseCrawler", "NewsItem", "SinaCrawlerTool", "create_sina_crawler", "TencentCrawlerTool", "JwviewCrawlerTool", "EeoCrawlerTool", "CaijingCrawlerTool", "Jingji21CrawlerTool", "NbdCrawlerTool", "YicaiCrawlerTool", "Netease163CrawlerTool", "EastmoneyCrawlerTool", "TextCleanerTool", "create_text_cleaner", "BochaAISearchTool", "bochaai_search", "SearchResult", ] ================================================ FILE: backend/app/tools/bochaai_search.py ================================================ """ BochaAI Web Search Tool 用于定向搜索股票相关新闻 """ import json import logging import urllib.request import urllib.error from typing import List, Dict, Any, Optional from datetime import datetime from dataclasses import dataclass from ..core.config import settings logger = logging.getLogger(__name__) @dataclass class SearchResult: """搜索结果数据类""" title: str url: str snippet: str site_name: Optional[str] = None date_published: Optional[str] = None class BochaAISearchTool: """ BochaAI Web Search 工具 用于搜索股票相关新闻 """ def __init__(self, api_key: Optional[str] = None, endpoint: Optional[str] = None): """ 初始化 BochaAI 搜索工具 Args: api_key: BochaAI API Key(如果不提供,从配置中获取) endpoint: API 端点(默认使用配置中的端点) """ self.api_key = api_key or settings.BOCHAAI_API_KEY self.endpoint = endpoint or settings.BOCHAAI_ENDPOINT if not self.api_key: logger.warning( "BochaAI API Key 未配置,搜索功能将不可用。\n" "请在 .env 文件中设置 BOCHAAI_API_KEY=your_api_key" ) def is_available(self) -> bool: """检查搜索功能是否可用""" return bool(self.api_key) def search( self, query: str, freshness: str = "noLimit", count: int = 10, offset: int = 0, include_sites: Optional[str] = None, exclude_sites: Optional[str] = None, ) -> List[SearchResult]: """ 执行 Web 搜索 Args: query: 搜索查询字符串 freshness: 时间范围(noLimit, day, week, month) count: 返回结果数量(1-50,单次最大50条) offset: 结果偏移量(用于分页) include_sites: 
限定搜索的网站(逗号分隔) exclude_sites: 排除的网站(逗号分隔) Returns: 搜索结果列表 """ if not self.is_available(): logger.warning("BochaAI API Key 未配置,跳过搜索") return [] try: # 构建请求数据 request_data = { "query": query, "freshness": freshness, "summary": False, "count": min(max(count, 1), 50) } # 添加 offset 参数进行分页 if offset > 0: request_data["offset"] = offset if include_sites: request_data["include"] = include_sites if exclude_sites: request_data["exclude"] = exclude_sites # 创建请求 req = urllib.request.Request( self.endpoint, data=json.dumps(request_data).encode('utf-8'), headers={ 'Authorization': f'Bearer {self.api_key}', 'Content-Type': 'application/json', 'User-Agent': 'FinnewsHunter-BochaAI-Search/1.0' } ) # 发送请求 with urllib.request.urlopen(req, timeout=30) as response: data = response.read().decode('utf-8') result = json.loads(data) # 解析结果 results = [] if 'data' in result: data = result['data'] if 'webPages' in data and data['webPages'] and 'value' in data['webPages']: for item in data['webPages']['value']: search_result = SearchResult( title=item.get('name', '无标题'), url=item.get('url', ''), snippet=item.get('snippet', ''), site_name=item.get('siteName', ''), date_published=item.get('datePublished', '') ) results.append(search_result) logger.info(f"BochaAI 搜索完成: query='{query}', offset={offset}, 结果数={len(results)}") return results except urllib.error.HTTPError as e: error_msg = f"BochaAI API HTTP 错误: {e.code} - {e.reason}" if e.code == 401: error_msg += " (请检查 BOCHAAI_API_KEY 是否正确)" elif e.code == 429: error_msg += " (请求过于频繁)" logger.error(error_msg) return [] except urllib.error.URLError as e: logger.error(f"BochaAI 网络错误: {e.reason}") return [] except json.JSONDecodeError as e: logger.error(f"BochaAI 响应解析失败: {e}") return [] except Exception as e: logger.error(f"BochaAI 搜索失败: {e}") return [] def search_stock_news( self, stock_name: str, stock_code: Optional[str] = None, days: int = 30, count: int = 100, max_age_days: int = 365, ) -> List[SearchResult]: """ 搜索股票相关新闻 Args: stock_name: 股票名称(如"贵州茅台") stock_code: 股票代码(可选,如"600519") days: 搜索时间范围(天),用于API freshness参数 count: 返回结果数量(支持超过50条,会自动分页请求) max_age_days: 最大新闻年龄(天),默认365天(1年),超过此时间的新闻将被过滤 Returns: 搜索结果列表(按时间从新到旧排序,只返回最近max_age_days天内的新闻) """ # 构建搜索查询 - 简洁明确,添加"最新"关键词优先获取新内容 query = f"{stock_name} 最新" # BochaAI API 支持的 freshness 参数值: # - noLimit: 不限制 # - oneDay: 一天内 # - oneWeek: 一周内 # - oneMonth: 一月内 # 注意:不支持 "year"、"day"、"week" 等其他值! 
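        # Worked examples of the mapping implemented just below:
        #   days=1 -> "oneDay", days=7 -> "oneWeek", days=30 -> "oneMonth",
        #   days=90 -> "noLimit" (older hits are then dropped locally via cutoff_date / max_age_days)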
# 根据请求天数确定 freshness 参数 if days <= 1: freshness = "oneDay" elif days <= 7: freshness = "oneWeek" elif days <= 30: freshness = "oneMonth" else: freshness = "noLimit" # 超过30天用 noLimit,本地再过滤 # 财经网站列表(用于优先搜索) finance_sites = ( "finance.sina.com.cn," "stock.eastmoney.com," "finance.qq.com," "money.163.com," "caijing.com.cn," "yicai.com," "nbd.com.cn," "21jingji.com," "eeo.com.cn," "chinanews.com.cn" ) # 计算截止时间(半年前) from datetime import timedelta cutoff_date = datetime.now() - timedelta(days=max_age_days) all_results = [] offset = 0 batch_size = 50 # API单次最大返回数 max_requests = 5 # 最多请求5次,防止无限循环 request_count = 0 logger.info(f"BochaAI 开始搜索股票新闻: {stock_name}, 目标数量={count}, 截止日期={cutoff_date.strftime('%Y-%m-%d')}") while len(all_results) < count and request_count < max_requests: batch_results = self.search( query=query, freshness=freshness, count=batch_size, offset=offset, include_sites=finance_sites ) if not batch_results: logger.info(f"BochaAI 第{request_count+1}次请求未返回结果,停止分页") break # 时间过滤:保留有日期且在范围内的新闻,以及无日期但可能相关的新闻 for result in batch_results: # 如果有发布日期,检查是否在时间范围内 if result.date_published: try: # 尝试解析发布时间 pub_date = datetime.fromisoformat( result.date_published.replace('Z', '+00:00') ) # 转换为无时区的时间进行比较 if pub_date.tzinfo: pub_date = pub_date.replace(tzinfo=None) # 检查是否在指定时间范围内 if pub_date < cutoff_date: logger.debug(f"过滤超过{max_age_days}天的新闻: {result.title[:30]}... ({result.date_published})") continue except (ValueError, AttributeError) as e: # 日期解析失败,但仍然保留(可能是新闻) logger.debug(f"无法解析日期,但仍保留: {result.title[:30]}...") else: # 无日期的新闻也保留(可能是相关新闻) logger.debug(f"无日期新闻,保留: {result.title[:30]}...") # 添加到结果中 all_results.append(result) if len(all_results) >= count: break offset += batch_size request_count += 1 logger.info(f"BochaAI 第{request_count}次请求完成,当前累计 {len(all_results)} 条有效结果") # 按发布时间排序(从新到旧) def parse_date(r): if r.date_published: try: dt = datetime.fromisoformat(r.date_published.replace('Z', '+00:00')) if dt.tzinfo: dt = dt.replace(tzinfo=None) return dt except (ValueError, AttributeError): pass return datetime.min # 无法解析的日期排在最后 all_results.sort(key=parse_date, reverse=True) logger.info(f"BochaAI 搜索股票新闻完成: {stock_name}, 返回 {len(all_results)} 条结果 (共请求{request_count}次, 仅保留最近{max_age_days}天即{max_age_days//30}个月内)") return all_results[:count] # 确保不超过请求数量 # 全局实例 bochaai_search = BochaAISearchTool() ================================================ FILE: backend/app/tools/caijing_crawler.py ================================================ """ 财经网爬虫工具 目标URL: https://www.caijing.com.cn/ (股市栏目) """ import re import logging from typing import List, Optional from datetime import datetime, timedelta from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class CaijingCrawlerTool(BaseCrawler): """ 财经网爬虫 主要爬取股市相关新闻 """ BASE_URL = "https://finance.caijing.com.cn/" # 股市栏目URL STOCK_URL = "https://finance.caijing.com.cn/" SOURCE_NAME = "caijing" def __init__(self): super().__init__( name="caijing_crawler", description="Crawl financial news from Caijing (caijing.com.cn)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取财经网新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled Caijing, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling Caijing: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: 
"""爬取单页新闻""" news_items = [] try: # 尝试爬取股市栏目或主页 try: response = self._fetch_page(self.STOCK_URL) except: response = self._fetch_page(self.BASE_URL) # 财经网编码处理 if response.encoding == 'ISO-8859-1' or not response.encoding: response.encoding = 'utf-8' soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links") # 限制爬取数量 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error crawling page: {e}") return news_items def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接 all_links = soup.find_all('a', href=True) # 财经网新闻URL模式(扩展更多模式) caijing_patterns = [ r'/\d{4}/', # 日期路径 /2024/ '/article/', # 文章 '.shtml', # 静态HTML '/finance/', # 财经频道 '/stock/', # 股票频道 ] for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 检查是否匹配财经网URL模式 is_caijing_url = False # 方式1: 检查URL模式 for pattern in caijing_patterns: if re.search(pattern, href): is_caijing_url = True break # 方式2: 检查是否包含caijing.com.cn域名 if 'caijing.com.cn' in href or 'finance.caijing.com.cn' in href: is_caijing_url = True # 方式3: 检查链接的class或data属性 if not is_caijing_url: link_class = link.get('class', []) if isinstance(link_class, list): link_class_str = ' '.join(link_class) else: link_class_str = str(link_class) if any(kw in link_class_str.lower() for kw in ['news', 'article', 'item', 'title', 'list']): if href.startswith('/') or 'caijing.com.cn' in href: is_caijing_url = True if is_caijing_url and title and len(title.strip()) > 5: # 规范化 URL,优先 https,避免重复前缀 if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://www.caijing.com.cn' + href elif href.startswith('http://'): href = href.replace('http://', 'https://', 1) elif not href.startswith('http'): href = 'https://www.caijing.com.cn/' + href.lstrip('/') # 过滤掉明显不是新闻的链接 if any(skip in href.lower() for skip in ['javascript:', 'mailto:', '#', 'void(0)', '/tag/', '/author/', '/user/']): continue if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title.strip()}) logger.debug(f"Caijing: Found {len(news_links)} potential news links") return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'article-content'}, {'class': 'main_txt'}, {'class': 'content'}, {'id': 'the_content'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content: return 
self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: author_elem = soup.find('span', {'class': re.compile(r'author|source')}) if author_elem: return author_elem.get_text(strip=True) except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/app/tools/crawler_base.py ================================================ """ 爬虫基类 符合 AgenticX BaseTool 协议 """ import time import logging from typing import List, Dict, Any, Optional from dataclasses import dataclass from datetime import datetime import requests from bs4 import BeautifulSoup import requests.exceptions from agenticx import BaseTool from agenticx.core import ToolMetadata, ToolCategory from ..core.config import settings logger = logging.getLogger(__name__) @dataclass class NewsItem: """新闻数据项""" title: str content: str url: str source: str publish_time: Optional[datetime] = None author: Optional[str] = None keywords: Optional[List[str]] = None stock_codes: Optional[List[str]] = None summary: Optional[str] = None raw_html: Optional[str] = None # 原始 HTML 内容 def to_dict(self) -> Dict[str, Any]: """转换为字典""" return { "title": self.title, "content": self.content, "url": self.url, "source": self.source, "publish_time": self.publish_time.isoformat() if self.publish_time else None, "author": self.author, "keywords": self.keywords, "stock_codes": self.stock_codes, "summary": self.summary, "raw_html": self.raw_html, } class BaseCrawler(BaseTool): """ 爬虫基类 继承自 AgenticX BaseTool """ # 股票相关URL关键词 STOCK_URL_KEYWORDS = [ '/stock/', '/gupiao/', '/securities/', '/zhengquan/', '/a-shares/', '/ashares/', '/equity/', '/shares/', '/market/', '/listed/', '/ipo/' ] # 股票相关标题关键词 STOCK_TITLE_KEYWORDS = [ '股票', 'A股', 'a股', '上市', '个股', '涨停', '跌停', 'IPO', 'ipo', '新股', '配股', '增发', '重组', '并购', '股东', '董事', '证券', '港股', '科创板', '创业板', '主板', '中小板', '北交所', '沪市', '深市', '股价', '股份', '停牌', '复牌', '退市', '借壳' ] def __init__(self, name: str = "base_crawler", description: str = "Base crawler for financial news"): # 创建 ToolMetadata metadata = ToolMetadata( name=name, description=description, category=ToolCategory.DATA_ACCESS, version="1.0.0" ) super().__init__(metadata=metadata) # 爬虫特定配置 self.user_agent = settings.CRAWLER_USER_AGENT self.timeout = settings.CRAWLER_TIMEOUT self.max_retries = settings.CRAWLER_MAX_RETRIES self.delay = settings.CRAWLER_DELAY self.session = requests.Session() self.session.headers.update({'User-Agent': self.user_agent}) def _fetch_page(self, url: str) -> requests.Response: """ 获取网页内容(带重试机制,但503错误不重试) Args: url: 目标URL Returns: 响应对象 """ max_retries = 3 for attempt in range(max_retries): try: response = self.session.get(url, timeout=self.timeout) # 对于503错误,不重试,直接抛出(让调用者处理) if response.status_code == 503: 
logger.debug(f"503 error for {url}, skipping retry (server overloaded)") response.raise_for_status() response.raise_for_status() # 修复编码问题:优先使用 apparent_encoding,如果检测失败则尝试常见编码 if response.encoding is None or response.encoding == 'ISO-8859-1': # 尝试检测真实编码 if response.apparent_encoding: response.encoding = response.apparent_encoding else: # 对于中文网站,尝试常见编码 encodings = ['utf-8', 'gb2312', 'gbk', 'gb18030'] for enc in encodings: try: # 尝试解码验证 response.content.decode(enc) response.encoding = enc break except (UnicodeDecodeError, LookupError): continue else: # 如果都失败,默认使用 utf-8 response.encoding = 'utf-8' time.sleep(self.delay) # 请求间隔 return response except requests.exceptions.HTTPError as e: # 503错误不重试,直接抛出 if e.response and e.response.status_code == 503: logger.debug(f"503 error for {url}, not retrying") raise # 其他HTTP错误,重试 if attempt < max_retries - 1: wait_time = min(2 ** attempt, 10) logger.warning(f"HTTP error fetching {url} (attempt {attempt + 1}/{max_retries}): {e}, retrying in {wait_time}s...") time.sleep(wait_time) else: logger.error(f"HTTP error fetching {url} after {max_retries} attempts: {e}") raise except Exception as e: # 其他错误,重试 if attempt < max_retries - 1: wait_time = min(2 ** attempt, 10) logger.warning(f"Error fetching {url} (attempt {attempt + 1}/{max_retries}): {e}, retrying in {wait_time}s...") time.sleep(wait_time) else: logger.error(f"Failed to fetch {url} after {max_retries} attempts: {e}") raise # 理论上不会到达这里 raise Exception(f"Failed to fetch {url} after {max_retries} attempts") def _parse_html(self, html: str) -> BeautifulSoup: """ 解析HTML Args: html: HTML字符串 Returns: BeautifulSoup对象 """ return BeautifulSoup(html, 'lxml') def _extract_chinese_ratio(self, text: str) -> float: """ 计算中文字符比例 Args: text: 文本 Returns: 中文字符比例(0-1) """ import re pattern = re.compile(r'[\u4e00-\u9fa5]+') chinese_chars = pattern.findall(text) chinese_count = sum(len(chars) for chars in chinese_chars) total_count = len(text) return chinese_count / total_count if total_count > 0 else 0 def _clean_text(self, text: str) -> str: """ 清理文本 Args: text: 原始文本 Returns: 清理后的文本 """ import re # 移除HTML标签 text = re.sub(r'<[^>]+>', '', text) # 移除特殊空格 text = text.replace('\u3000', ' ') # 移除多余空格和换行 text = ' '.join(text.split()) return text.strip() def _extract_article_content(self, soup: BeautifulSoup, selectors: List[dict] = None) -> str: """ 通用智能内容提取方法 Args: soup: BeautifulSoup对象 selectors: 可选的自定义选择器列表 Returns: 提取的正文内容 """ import re # 默认选择器(按优先级排序) default_selectors = [ # 文章主体选择器 {'class': re.compile(r'article[-_]?(body|content|text|main)', re.I)}, {'class': re.compile(r'content[-_]?(article|body|text|main)', re.I)}, {'class': re.compile(r'main[-_]?(content|body|text|article)', re.I)}, {'class': re.compile(r'^(article|content|body|text|post)$', re.I)}, {'itemprop': 'articleBody'}, {'id': re.compile(r'(article|content|body|text)[-_]?(content|body|text)?', re.I)}, # 通用选择器 {'class': 'g-article-content'}, {'class': 'article-content'}, {'class': 'news-content'}, {'id': 'contentText'}, ] all_selectors = (selectors or []) + default_selectors for selector in all_selectors: content_div = soup.find(['div', 'article', 'section', 'main'], selector) if content_div: # 移除无关元素 for tag in content_div.find_all(['script', 'style', 'iframe', 'ins', 'noscript', 'nav', 'footer', 'header']): tag.decompose() for ad in content_div.find_all(class_=re.compile(r'(ad|advertisement|banner|recommend|related|share|comment)', re.I)): ad.decompose() # 提取所有段落(不限制数量) paragraphs = content_div.find_all('p') if paragraphs: content = 
'\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content and len(content) > 50: return self._clean_text(content) # 如果没有 p 标签,直接取文本 text = content_div.get_text(separator='\n', strip=True) if text and len(text) > 50: return self._clean_text(text) # 后备方案:取所有符合条件的段落(不限制数量) paragraphs = soup.find_all('p') if paragraphs: valid_paragraphs = [ p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True) and len(p.get_text(strip=True)) > 15 and not any(kw in p.get_text(strip=True).lower() for kw in ['copyright', '版权', '广告', 'advertisement']) ] content = '\n'.join(valid_paragraphs) if content: return self._clean_text(content) return "" def _is_stock_related_by_url(self, url: str) -> bool: """ 根据URL路径判断是否为股票相关新闻 Args: url: 新闻URL Returns: 是否为股票相关 """ url_lower = url.lower() return any(keyword in url_lower for keyword in self.STOCK_URL_KEYWORDS) def _is_stock_related_by_title(self, title: str) -> bool: """ 根据标题关键词判断是否为股票相关新闻 Args: title: 新闻标题 Returns: 是否为股票相关 """ return any(keyword in title for keyword in self.STOCK_TITLE_KEYWORDS) def _filter_stock_news(self, news_list: List[NewsItem]) -> List[NewsItem]: """ 筛选股票相关新闻 组合URL路径和标题关键词两种策略 策略调整: - 如果过滤后没有新闻,返回所有新闻(避免过度过滤) - 对于财经类网站,放宽筛选条件 Args: news_list: 原始新闻列表 Returns: 股票相关新闻列表 """ filtered_news = [] url_matched = 0 title_matched = 0 filtered_out = 0 for news in news_list: # URL匹配 或 标题匹配 url_match = self._is_stock_related_by_url(news.url) title_match = self._is_stock_related_by_title(news.title) if url_match or title_match: filtered_news.append(news) if url_match: url_matched += 1 if title_match: title_matched += 1 logger.debug(f"✓ Stock news matched: {news.title[:50]}... (URL:{url_match}, Title:{title_match})") else: filtered_out += 1 # 只记录前5条被过滤的,避免日志过多 if filtered_out <= 5: logger.debug(f"✗ Filtered out: {news.title[:50]}...") logger.info(f"Stock filter [{self.SOURCE_NAME}]: {len(news_list)} -> {len(filtered_news)} items " f"(URL matched: {url_matched}, Title matched: {title_matched}, Filtered: {filtered_out})") # 如果过滤后没有新闻,返回所有新闻(避免过度过滤) # 这对于财经类网站特别重要,因为它们的新闻通常都与金融相关 if len(news_list) > 0 and len(filtered_news) == 0: logger.warning(f"⚠️ All {len(news_list)} news items were filtered out for source {self.SOURCE_NAME}. " f"Returning all news to avoid over-filtering.") return news_list return filtered_news def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取新闻 Args: start_page: 起始页 end_page: 结束页 Returns: 新闻列表 """ raise NotImplementedError("Subclass must implement crawl method") def _setup_parameters(self): """设置工具参数(AgenticX 要求)""" pass # 爬虫不需要特殊参数设置 def execute(self, **kwargs) -> Dict[str, Any]: """ 同步执行方法(AgenticX Tool 协议要求) Args: **kwargs: 参数字典 - start_page: 起始页 - end_page: 结束页 Returns: 执行结果 """ start_page = kwargs.get('start_page', 1) end_page = kwargs.get('end_page', 1) logger.info(f"Crawling from page {start_page} to {end_page}") news_list = self.crawl(start_page, end_page) return { "success": True, "count": len(news_list), "news_list": [news.to_dict() for news in news_list], } async def aexecute(self, **kwargs) -> Dict[str, Any]: """ 异步执行方法(AgenticX Tool 协议要求) 当前实现为同步执行的包装 Args: **kwargs: 参数字典 Returns: 执行结果 """ return self.execute(**kwargs) ================================================ FILE: backend/app/tools/crawler_enhanced.py ================================================ """ 增强版爬虫模块 整合 deer-flow、BasicWebCrawler 和现有爬虫的优点 特性: 1. 多引擎支持:本地爬取 + Jina Reader API + Playwright JS 渲染 2. 智能内容提取:readabilipy + 启发式算法 3. 网站特定配置 4. 内容质量评估与自动重试 5. 缓存和去重 6. 
统一 Article 模型,支持 LLM 消息格式 """ import re import os import json import time import hashlib import logging from typing import List, Dict, Any, Optional, Literal from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from urllib.parse import urlparse, urljoin import requests from bs4 import BeautifulSoup from tenacity import retry, stop_after_attempt, wait_exponential # 可选依赖 try: from markdownify import markdownify as md except ImportError: md = None try: from readabilipy import simple_json_from_html_string except ImportError: simple_json_from_html_string = None try: from playwright.sync_api import sync_playwright except ImportError: sync_playwright = None logger = logging.getLogger(__name__) # ============ 配置 ============ # 财经新闻网站特定配置 FINANCE_SITE_CONFIGS = { # 新浪财经 'finance.sina.com.cn': { 'main_content_selectors': [ '.article-content', '.article', '#artibody', '.main-content', '.post-body' ], 'title_selectors': ['h1.main-title', 'h1', '.article-title'], 'time_selectors': ['.date', '.pub_date', '.time-source'], 'needs_js': False, 'headers': { 'Referer': 'https://finance.sina.com.cn/', } }, # 东方财富 'eastmoney.com': { 'main_content_selectors': [ '.article-content', '#ContentBody', '.newsContent', '.article', '.content-article' ], 'needs_js': True, 'wait_selectors': ['.article-content', '#ContentBody'], }, # 每经网 'nbd.com.cn': { 'main_content_selectors': [ '.article-content', '.g-article-content', '.article-detail', '.post-content' ], 'needs_js': False, }, # 财新 'caixin.com': { 'main_content_selectors': [ '#Main_Content_Val', '.article-content', '.articleBody', '.main-content' ], 'needs_cookies': True, # 付费内容 'needs_js': False, }, # 腾讯财经 'finance.qq.com': { 'main_content_selectors': [ '.content-article', '.Cnt-Main-Article-QQ', '#Cnt-Main-Article-QQ', '.article-content' ], 'needs_js': False, }, # 21世纪经济报道 '21jingji.com': { 'main_content_selectors': [ '.article-content', '.detailContent', '.article-body', '.post-content' ], 'needs_js': False, }, # 默认配置 'default': { 'main_content_selectors': [ 'article', 'main', '.article', '.content', '.post-content', '.entry-content', '#content' ], 'needs_js': False, 'headers': {} } } # ============ Article 模型 ============ @dataclass class Article: """ 统一的文章模型(参考 deer-flow) 支持转换为 Markdown 和 LLM 消息格式 """ title: str content: str # 纯文本内容 html_content: Optional[str] = None # 原始 HTML url: str = "" source: str = "" publish_time: Optional[datetime] = None author: Optional[str] = None keywords: List[str] = field(default_factory=list) stock_codes: List[str] = field(default_factory=list) images: List[str] = field(default_factory=list) # 元数据 crawl_time: datetime = field(default_factory=datetime.utcnow) engine_used: str = "" # 使用的爬取引擎 quality_score: float = 0.0 # 内容质量评分 def to_markdown(self, include_title: bool = True, include_meta: bool = False) -> str: """转换为 Markdown 格式""" parts = [] if include_title and self.title: parts.append(f"# {self.title}\n") if include_meta: meta = [] if self.source: meta.append(f"来源: {self.source}") if self.publish_time: meta.append(f"时间: {self.publish_time.strftime('%Y-%m-%d %H:%M')}") if self.author: meta.append(f"作者: {self.author}") if self.url: meta.append(f"原文: {self.url}") if meta: parts.append(f"*{' | '.join(meta)}*\n") # 如果有 HTML 内容且安装了 markdownify,转换它 if self.html_content and md: parts.append(md(self.html_content)) else: parts.append(self.content) return "\n".join(parts) def to_llm_message(self) -> List[Dict[str, Any]]: """ 转换为 LLM 消息格式(参考 deer-flow) 将图片和文本分离,便于多模态 LLM 处理 """ content: 
List[Dict[str, str]] = [] markdown = self.to_markdown() if not markdown.strip(): return [{"type": "text", "text": "No content available"}] # 提取图片 URL image_pattern = r"!\[.*?\]\((.*?)\)" parts = re.split(image_pattern, markdown) for i, part in enumerate(parts): if i % 2 == 1: # 图片 URL image_url = urljoin(self.url, part.strip()) content.append({ "type": "image_url", "image_url": {"url": image_url} }) else: # 文本内容 text_part = part.strip() if text_part: content.append({"type": "text", "text": text_part}) return content if content else [{"type": "text", "text": "No content available"}] def to_dict(self) -> Dict[str, Any]: """转换为字典""" return { "title": self.title, "content": self.content, "html_content": self.html_content, "url": self.url, "source": self.source, "publish_time": self.publish_time.isoformat() if self.publish_time else None, "author": self.author, "keywords": self.keywords, "stock_codes": self.stock_codes, "images": self.images, "crawl_time": self.crawl_time.isoformat(), "engine_used": self.engine_used, "quality_score": self.quality_score, } # ============ 内容提取器 ============ class ContentExtractor: """ 智能内容提取器 结合 readabilipy 和启发式算法 """ @staticmethod def extract_with_readability(html: str) -> Optional[Article]: """使用 readabilipy 提取(参考 deer-flow)""" if simple_json_from_html_string is None: return None try: result = simple_json_from_html_string(html, use_readability=True) content = result.get("content", "") title = result.get("title", "Untitled") if not content or len(content.strip()) < 100: return None return Article( title=title, content=BeautifulSoup(content, 'html.parser').get_text(separator='\n', strip=True), html_content=content, ) except Exception as e: logger.warning(f"Readability extraction failed: {e}") return None @staticmethod def extract_with_selectors(soup: BeautifulSoup, config: dict) -> Optional[Article]: """使用 CSS 选择器提取""" # 提取标题 title = None for sel in config.get('title_selectors', ['h1', 'title']): el = soup.select_one(sel) if el: title = el.get_text(strip=True) break if not title: title_el = soup.find('title') title = title_el.get_text(strip=True) if title_el else "Untitled" # 提取主要内容 content_el = None for sel in config.get('main_content_selectors', []): content_el = soup.select_one(sel) if content_el and len(content_el.get_text(strip=True)) > 100: break if not content_el: return None # 清理内容 for tag in content_el.find_all(['script', 'style', 'nav', 'footer', 'aside']): tag.decompose() content = content_el.get_text(separator='\n', strip=True) html_content = str(content_el) if len(content) < 100: return None return Article( title=title, content=content, html_content=html_content, ) @staticmethod def extract_heuristic(soup: BeautifulSoup) -> Optional[Article]: """ 启发式内容提取(参考 BasicWebCrawler) 找到包含最多段落文本的元素 """ # 提取标题 title_el = soup.find('title') title = title_el.get_text(strip=True) if title_el else "Untitled" # 排除导航等元素 for tag in soup.find_all(['script', 'style', 'nav', 'footer', 'aside', 'header', '.sidebar', '.advertisement']): if hasattr(tag, 'decompose'): tag.decompose() # 找到最佳内容容器 candidates = [] for tag in ['article', 'main', 'section', 'div']: for elem in soup.find_all(tag): # 排除导航、侧边栏等 elem_class = ' '.join(elem.get('class', [])).lower() elem_id = (elem.get('id') or '').lower() exclude_keywords = ['nav', 'sidebar', 'footer', 'header', 'menu', 'ad', 'banner', 'comment'] if any(kw in elem_class or kw in elem_id for kw in exclude_keywords): continue text = elem.get_text(strip=True) text_len = len(text) if text_len > 200: score = text_len # 有标题加分 if 
elem.find(['h1', 'h2', 'h3']): score += 1000 # 有段落加分 p_count = len(elem.find_all('p')) score += p_count * 50 candidates.append((elem, score, text_len)) if not candidates: return None # 选择得分最高的 best_elem = max(candidates, key=lambda x: x[1])[0] content = best_elem.get_text(separator='\n', strip=True) return Article( title=title, content=content, html_content=str(best_elem), ) @classmethod def extract(cls, html: str, url: str = "", config: dict = None) -> Article: """ 智能提取:依次尝试多种方法 1. readabilipy(最智能) 2. CSS 选择器(网站特定) 3. 启发式算法(兜底) """ soup = BeautifulSoup(html, 'html.parser') config = config or FINANCE_SITE_CONFIGS.get('default', {}) # 方法 1: readabilipy article = cls.extract_with_readability(html) if article and article.quality_score > 0.5: article.engine_used = "readability" return article # 方法 2: CSS 选择器 article = cls.extract_with_selectors(soup, config) if article: article.engine_used = "selectors" return article # 方法 3: 启发式 article = cls.extract_heuristic(soup) if article: article.engine_used = "heuristic" return article # 兜底:返回整个 body body = soup.find('body') return Article( title=soup.title.string if soup.title else "Untitled", content=body.get_text(separator='\n', strip=True) if body else "", html_content=str(body) if body else "", engine_used="fallback", ) # ============ 爬取引擎 ============ class JinaReaderEngine: """ Jina Reader API 引擎(参考 deer-flow) https://jina.ai/reader """ API_URL = "https://r.jina.ai/" def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or os.getenv("JINA_API_KEY") def crawl(self, url: str, return_format: str = "html") -> Optional[str]: """爬取 URL""" headers = { "Content-Type": "application/json", "X-Return-Format": return_format, } if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" try: response = requests.post( self.API_URL, headers=headers, json={"url": url}, timeout=30 ) if response.status_code != 200: logger.error(f"Jina API error: {response.status_code}") return None return response.text except Exception as e: logger.error(f"Jina crawl failed: {e}") return None class PlaywrightEngine: """ Playwright 浏览器引擎(参考 BasicWebCrawler) 支持 JS 渲染 """ def __init__(self, headless: bool = True): self.headless = headless def crawl(self, url: str, wait_selectors: List[str] = None, timeout_ms: int = 15000) -> Optional[str]: """使用 Playwright 爬取""" if sync_playwright is None: logger.warning("Playwright not installed") return None try: with sync_playwright() as p: browser = p.chromium.launch( headless=self.headless, args=['--disable-blink-features=AutomationControlled'] ) context = browser.new_context( user_agent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ' 'AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36', viewport={'width': 1920, 'height': 1080}, ) # 反检测 context.add_init_script(""" Object.defineProperty(navigator, 'webdriver', { get: () => undefined }); """) page = context.new_page() page.goto(url, wait_until='networkidle', timeout=timeout_ms) # 等待选择器 if wait_selectors: for sel in wait_selectors: try: page.wait_for_selector(sel, timeout=5000) break except Exception: continue # 等待内容稳定 page.wait_for_timeout(1000) content = page.content() context.close() browser.close() return content except Exception as e: logger.error(f"Playwright crawl failed: {e}") return None class RequestsEngine: """ 基础 Requests 引擎 """ DEFAULT_HEADERS = { 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ' 'AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36', 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 
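# Note: these desktop-Chrome headers are applied to the whole session in __init__;
# crawl() below is wrapped in tenacity's @retry (3 attempts, exponential wait between
# 2s and 10s) and resets response.encoding to apparent_encoding, which matters for
# GBK/GB2312-encoded Chinese finance sites.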
'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8', } def __init__(self, timeout: int = 20): self.timeout = timeout self.session = requests.Session() self.session.headers.update(self.DEFAULT_HEADERS) @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10)) def crawl(self, url: str, headers: dict = None, cookies: dict = None) -> Optional[str]: """爬取 URL""" try: response = self.session.get( url, headers=headers, cookies=cookies, timeout=self.timeout ) response.raise_for_status() response.encoding = response.apparent_encoding return response.text except Exception as e: logger.error(f"Requests crawl failed: {e}") raise # ============ 缓存 ============ class CrawlCache: """ 爬取缓存(参考 BasicWebCrawler) """ def __init__(self, cache_dir: str = ".crawl_cache", ttl_hours: int = 24): self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self.ttl_seconds = ttl_hours * 3600 def _key(self, url: str) -> str: return hashlib.md5(url.encode()).hexdigest() def get(self, url: str) -> Optional[str]: """获取缓存""" key = self._key(url) cache_file = self.cache_dir / f"{key}.json" if not cache_file.exists(): return None try: data = json.loads(cache_file.read_text(encoding='utf-8')) cached_time = datetime.fromisoformat(data['time']) if (datetime.utcnow() - cached_time).total_seconds() > self.ttl_seconds: cache_file.unlink() # 过期删除 return None return data['html'] except Exception: return None def set(self, url: str, html: str): """设置缓存""" key = self._key(url) cache_file = self.cache_dir / f"{key}.json" try: data = { 'url': url, 'time': datetime.utcnow().isoformat(), 'html': html, } cache_file.write_text(json.dumps(data, ensure_ascii=False), encoding='utf-8') except Exception as e: logger.warning(f"Cache write failed: {e}") # ============ 主爬虫类 ============ class EnhancedCrawler: """ 增强版爬虫 自动选择最佳引擎,智能提取内容 """ def __init__( self, use_cache: bool = True, cache_ttl_hours: int = 24, jina_api_key: Optional[str] = None, default_engine: Literal['requests', 'playwright', 'jina'] = 'requests' ): self.use_cache = use_cache self.cache = CrawlCache(ttl_hours=cache_ttl_hours) if use_cache else None # 初始化引擎 self.requests_engine = RequestsEngine() self.playwright_engine = PlaywrightEngine() self.jina_engine = JinaReaderEngine(api_key=jina_api_key) self.default_engine = default_engine def _get_site_config(self, url: str) -> dict: """获取网站配置""" domain = urlparse(url).netloc for site_domain, config in FINANCE_SITE_CONFIGS.items(): if site_domain in domain: return config return FINANCE_SITE_CONFIGS['default'] def _evaluate_quality(self, article: Article) -> float: """ 评估内容质量 返回 0-1 的分数 """ score = 0.0 # 内容长度 content_len = len(article.content) if content_len > 500: score += 0.3 elif content_len > 200: score += 0.2 elif content_len > 100: score += 0.1 # 有标题 if article.title and article.title != "Untitled": score += 0.2 # 中文内容比例(财经新闻应该主要是中文) chinese_pattern = re.compile(r'[\u4e00-\u9fa5]') chinese_count = len(chinese_pattern.findall(article.content)) if content_len > 0: chinese_ratio = chinese_count / content_len if chinese_ratio > 0.5: score += 0.3 elif chinese_ratio > 0.3: score += 0.2 # 段落结构 paragraph_count = article.content.count('\n') if paragraph_count > 5: score += 0.2 elif paragraph_count > 2: score += 0.1 return min(score, 1.0) def crawl( self, url: str, engine: Optional[Literal['requests', 'playwright', 'jina', 'auto']] = None, force_refresh: bool = False ) -> Article: """ 爬取单个 URL Args: url: 目标 URL engine: 爬取引擎 ('requests', 'playwright', 'jina', 'auto') force_refresh: 是否强制刷新缓存 Returns: Article 对象 
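        Example (illustrative sketch; the URL is just a placeholder):

            crawler = EnhancedCrawler(use_cache=True)
            article = crawler.crawl("https://finance.sina.com.cn/stock/example.shtml", engine="auto")
            # 'auto' picks Playwright for sites whose config sets needs_js, otherwise requests;
            # if the primary engine returns no (or too little) HTML, Jina Reader (when an API
            # key is configured) and Playwright are tried as fallbacks, and extractions scoring
            # below 0.3 are re-fetched once via Jina.
            print(article.engine_used, article.quality_score, len(article.content))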
""" # 检查缓存 if self.use_cache and not force_refresh: cached_html = self.cache.get(url) if cached_html: logger.info(f"Using cached content for {url}") article = ContentExtractor.extract(cached_html, url) article.url = url article.quality_score = self._evaluate_quality(article) return article # 获取网站配置 config = self._get_site_config(url) engine = engine or self.default_engine html = None used_engine = engine # 自动选择引擎 if engine == 'auto': if config.get('needs_js'): engine = 'playwright' else: engine = 'requests' # 爬取 if engine == 'requests': html = self.requests_engine.crawl( url, headers=config.get('headers'), cookies=config.get('cookies') ) used_engine = 'requests' elif engine == 'playwright': html = self.playwright_engine.crawl( url, wait_selectors=config.get('wait_selectors') ) used_engine = 'playwright' elif engine == 'jina': html = self.jina_engine.crawl(url) used_engine = 'jina' # 如果主引擎失败,尝试备用引擎 if not html or len(html) < 500: logger.warning(f"Primary engine failed, trying fallback...") if used_engine != 'jina' and self.jina_engine.api_key: html = self.jina_engine.crawl(url) used_engine = 'jina' if not html and used_engine != 'playwright' and sync_playwright: html = self.playwright_engine.crawl(url) used_engine = 'playwright' if not html: logger.error(f"All engines failed for {url}") return Article( title="Crawl Failed", content=f"Failed to crawl {url}", url=url, engine_used="none", quality_score=0.0 ) # 缓存 if self.use_cache: self.cache.set(url, html) # 提取内容 article = ContentExtractor.extract(html, url, config) article.url = url article.source = urlparse(url).netloc article.engine_used = used_engine article.quality_score = self._evaluate_quality(article) # 质量检查:如果质量太低且没用过 Jina,尝试用 Jina if article.quality_score < 0.3 and used_engine != 'jina' and self.jina_engine.api_key: logger.info(f"Low quality ({article.quality_score:.2f}), retrying with Jina...") jina_html = self.jina_engine.crawl(url) if jina_html: jina_article = ContentExtractor.extract(jina_html, url, config) jina_article.quality_score = self._evaluate_quality(jina_article) if jina_article.quality_score > article.quality_score: article = jina_article article.engine_used = 'jina' return article def crawl_batch( self, urls: List[str], engine: Optional[str] = None, delay: float = 1.0 ) -> List[Article]: """ 批量爬取 Args: urls: URL 列表 engine: 爬取引擎 delay: 请求间隔(秒) Returns: Article 列表 """ articles = [] for i, url in enumerate(urls): logger.info(f"Crawling {i+1}/{len(urls)}: {url}") try: article = self.crawl(url, engine=engine) articles.append(article) except Exception as e: logger.error(f"Failed to crawl {url}: {e}") articles.append(Article( title="Crawl Failed", content=str(e), url=url, quality_score=0.0 )) if delay > 0 and i < len(urls) - 1: time.sleep(delay) return articles # ============ 便捷函数 ============ # 全局爬虫实例 _crawler: Optional[EnhancedCrawler] = None def get_crawler() -> EnhancedCrawler: """获取全局爬虫实例""" global _crawler if _crawler is None: _crawler = EnhancedCrawler() return _crawler def crawl_url(url: str, engine: str = 'auto') -> Article: """便捷函数:爬取单个 URL""" return get_crawler().crawl(url, engine=engine) def crawl_urls(urls: List[str], engine: str = 'auto') -> List[Article]: """便捷函数:批量爬取""" return get_crawler().crawl_batch(urls, engine=engine) # ============ 测试 ============ if __name__ == "__main__": logging.basicConfig(level=logging.INFO) # 测试爬取 test_urls = [ "https://finance.sina.com.cn/roll/c/56592.shtml", ] crawler = EnhancedCrawler(use_cache=True) for url in test_urls: print(f"\n{'='*60}") print(f"Crawling: {url}") article = 
crawler.crawl(url, engine='auto') print(f"Title: {article.title}") print(f"Engine: {article.engine_used}") print(f"Quality: {article.quality_score:.2f}") print(f"Content length: {len(article.content)}") print(f"Preview: {article.content[:200]}...") ================================================ FILE: backend/app/tools/dynamic_crawler_example.py ================================================ """ 动态网站爬虫示例 - 使用 Selenium 适用于需要点击"加载更多"的网站 依赖安装: pip install selenium webdriver-manager """ import logging from typing import List, Optional from datetime import datetime from selenium import webdriver from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.chrome.options import Options from selenium.webdriver.chrome.service import Service from webdriver_manager.chrome import ChromeDriverManager from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class DynamicCrawlerExample(BaseCrawler): """ 动态网站爬虫示例 支持点击"加载更多"按钮 """ BASE_URL = "https://www.eeo.com.cn/" STOCK_URL = "https://www.eeo.com.cn/jg/jinrong/zhengquan/" SOURCE_NAME = "eeo_dynamic" def __init__(self): super().__init__( name="eeo_dynamic_crawler", description="Crawl EEO with dynamic loading support" ) self.driver = None def _init_driver(self): """初始化 Selenium WebDriver""" if self.driver: return chrome_options = Options() chrome_options.add_argument('--headless') # 无头模式 chrome_options.add_argument('--no-sandbox') chrome_options.add_argument('--disable-dev-shm-usage') chrome_options.add_argument(f'user-agent={self.user_agent}') service = Service(ChromeDriverManager().install()) self.driver = webdriver.Chrome(service=service, options=chrome_options) logger.info("Selenium WebDriver initialized") def _close_driver(self): """关闭 WebDriver""" if self.driver: self.driver.quit() self.driver = None def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取新闻(支持动态加载) Args: start_page: 起始页(对于点击加载更多的网站,这个参数表示点击次数) end_page: 结束页 Returns: 新闻列表 """ news_list = [] try: self._init_driver() page_news = self._crawl_with_selenium() news_list.extend(page_news) logger.info(f"Crawled EEO (dynamic), got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling EEO (dynamic): {e}") finally: self._close_driver() # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_with_selenium(self) -> List[NewsItem]: """使用 Selenium 爬取动态加载的内容""" news_items = [] try: # 1. 访问页面 self.driver.get(self.STOCK_URL) logger.info(f"Loaded page: {self.STOCK_URL}") # 2. 等待页面加载 WebDriverWait(self.driver, 10).until( EC.presence_of_element_located((By.TAG_NAME, "body")) ) # 3. 尝试点击"加载更多"按钮(如果存在) click_count = 0 max_clicks = 3 # 最多点击3次"加载更多" while click_count < max_clicks: try: # 查找"加载更多"按钮(根据实际页面调整选择器) load_more_button = self.driver.find_element( By.XPATH, "//button[contains(text(), '加载更多')] | //div[contains(text(), '点击加载更多')]" ) # 滚动到按钮位置 self.driver.execute_script("arguments[0].scrollIntoView();", load_more_button) # 点击按钮 load_more_button.click() click_count += 1 logger.info(f"Clicked 'Load More' button {click_count} times") # 等待新内容加载 import time time.sleep(2) except Exception as e: logger.debug(f"No more 'Load More' button or click failed: {e}") break # 4. 提取所有新闻链接 news_links = self._extract_news_links_from_selenium() logger.info(f"Found {len(news_links)} news links") # 5. 
爬取每条新闻的详情 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error in Selenium crawling: {e}") return news_items def _extract_news_links_from_selenium(self) -> List[dict]: """从 Selenium 页面中提取新闻链接""" news_links = [] try: # 查找所有新闻链接(根据实际页面结构调整选择器) link_elements = self.driver.find_elements(By.CSS_SELECTOR, "a[href*='/article/']") for element in link_elements: try: href = element.get_attribute('href') title = element.text.strip() if href and title and href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) except Exception as e: continue except Exception as e: logger.error(f"Error extracting links: {e}") return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情(使用传统 requests 方式)""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) soup = self._parse_html(response.text) # 提取正文(简化示例) content_div = soup.find('div', class_='article-content') if content_div: content = content_div.get_text(strip=True) else: content = "" if not content: return None return NewsItem( title=title, content=self._clean_text(content), url=url, source=self.SOURCE_NAME, publish_time=datetime.now(), ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None # 使用示例 if __name__ == "__main__": crawler = DynamicCrawlerExample() news = crawler.crawl() print(f"Crawled {len(news)} news items") for item in news[:5]: print(f"- {item.title}") ================================================ FILE: backend/app/tools/eastmoney_crawler.py ================================================ """ 东方财富爬虫工具 目标URL: https://stock.eastmoney.com/ """ import re import logging from typing import List, Optional from datetime import datetime from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class EastmoneyCrawlerTool(BaseCrawler): """ 东方财富爬虫 主要爬取股市新闻 """ BASE_URL = "https://stock.eastmoney.com/" STOCK_URL = "https://stock.eastmoney.com/news/" SOURCE_NAME = "eastmoney" def __init__(self): super().__init__( name="eastmoney_crawler", description="Crawl financial news from East Money (eastmoney.com)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取东方财富新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled Eastmoney, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling Eastmoney: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: """爬取单页新闻""" news_items = [] try: # 尝试爬取新闻栏目或主页 try: response = self._fetch_page(self.STOCK_URL) except: response = self._fetch_page(self.BASE_URL) # 东方财富编码处理 if response.encoding == 'ISO-8859-1' or not response.encoding: response.encoding = 'utf-8' soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links") # 限制爬取数量 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except 
Exception as e: logger.error(f"Error crawling page: {e}") return news_items def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接 all_links = soup.find_all('a', href=True) # 东方财富新闻URL模式(扩展更多模式) eastmoney_patterns = [ '/news/', # 新闻频道 '/stock/', # 股票频道 '/a/', # 文章 '/article/', # 文章 '.html', # HTML页面 '/guba/', # 股吧 ] for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 检查是否匹配东方财富URL模式 is_eastmoney_url = False # 方式1: 检查是否包含eastmoney.com域名 if 'eastmoney.com' in href or 'eastmoney.cn' in href: for pattern in eastmoney_patterns: if pattern in href: is_eastmoney_url = True break # 方式2: 相对路径且匹配模式 if not is_eastmoney_url and href.startswith('/'): for pattern in eastmoney_patterns: if pattern in href: is_eastmoney_url = True break # 方式3: 检查data属性或class中包含新闻标识 if not is_eastmoney_url: link_class = link.get('class', []) if isinstance(link_class, list): link_class_str = ' '.join(link_class) else: link_class_str = str(link_class) if any(kw in link_class_str.lower() for kw in ['news', 'article', 'item', 'title']): if any(pattern in href for pattern in ['/a/', '/news/', '.html']): is_eastmoney_url = True if is_eastmoney_url and title and len(title.strip()) > 5: # 确保是完整URL if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): # 判断是stock还是www域名 if '/stock/' in href or '/guba/' in href: href = 'https://stock.eastmoney.com' + href else: href = 'https://www.eastmoney.com' + href elif not href.startswith('http'): href = 'https://stock.eastmoney.com/' + href.lstrip('/') # 过滤掉明显不是新闻的链接 if any(skip in href.lower() for skip in ['javascript:', 'mailto:', '#', 'void(0)', '/guba/']): continue if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title.strip()}) logger.debug(f"Eastmoney: Found {len(news_links)} potential news links") return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'Body'}, {'id': 'ContentBody'}, {'class': 'article-content'}, {'class': 'newsContent'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content: return self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('div', {'class': re.compile(r'time|date')}) if not time_elem: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def 
_parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: author_elem = soup.find('div', {'class': re.compile(r'author|source')}) if not author_elem: author_elem = soup.find('span', {'class': re.compile(r'author|source')}) if author_elem: return author_elem.get_text(strip=True) except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/app/tools/eeo_crawler.py ================================================ """ 经济观察网爬虫工具 目标URL: https://www.eeo.com.cn/jg/jinrong/zhengquan/ """ import re import json import logging from typing import List, Optional from datetime import datetime from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class EeoCrawlerTool(BaseCrawler): """ 经济观察网爬虫 主要爬取证券栏目 使用官方API接口 """ BASE_URL = "https://www.eeo.com.cn/" # 证券栏目URL(用于获取uuid) STOCK_URL = "https://www.eeo.com.cn/jg/jinrong/zhengquan/" # API接口URL API_URL = "https://app.eeo.com.cn/" SOURCE_NAME = "eeo" # 证券频道的UUID(通过访问页面获取) CHANNEL_UUID = "9905934f8ec548ddae87652dbb9eebc6" def __init__(self): super().__init__( name="eeo_crawler", description="Crawl financial news from Economic Observer (eeo.com.cn)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取经济观察网新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled EEO, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling EEO: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _fetch_api_news(self, page: int = 0, prev_uuid: str = "", prev_publish_date: str = "") -> List[dict]: """ 通过API获取新闻列表 Args: page: 页码(从0开始) prev_uuid: 上一条新闻的UUID(用于翻页) prev_publish_date: 上一条新闻的发布时间(用于翻页) Returns: 新闻列表 """ try: # 构建API参数 params = { "app": "article", "controller": "index", "action": "getMoreArticle", "uuid": self.CHANNEL_UUID, "page": page, "pageSize": 20, # 每页20条 "prevUuid": prev_uuid, "prevPublishDate": prev_publish_date, } # 添加必要的请求头 headers = { "User-Agent": self.user_agent, "Referer": self.STOCK_URL, "Accept": "*/*", } response = self.session.get( self.API_URL, params=params, headers=headers, timeout=self.timeout ) response.raise_for_status() # 处理JSONP响应 # 响应格式可能是: jQuery11130...callback({"code":200,"data":[...]}) # 或者直接是JSON: {"code":200,"data":[...]} content = response.text.strip() logger.debug(f"[EEO] API response preview (first 300 chars): {content[:300]}") # 尝试1: 如果是JSONP格式,提取JSON部分 json_match = re.search(r'\((.*)\)$', content) if json_match: try: json_str = json_match.group(1) data = json.loads(json_str) # 支持两种格式:status==1 或 code==200 if (data.get('status') == 1 or data.get('code') == 200) and 'data' in data: logger.info(f"[EEO] Successfully parsed JSONP, found {len(data['data'])} items") return data['data'] except json.JSONDecodeError as e: logger.debug(f"[EEO] JSONP parse failed: {e}") pass # 尝试2: 直接解析JSON try: data = json.loads(content) if isinstance(data, dict): # 支持两种格式:status==1 或 code==200 if (data.get('status') == 1 or data.get('code') == 200) and 'data' in data: logger.info(f"[EEO] Successfully 
parsed JSON, found {len(data['data'])} items") return data['data'] elif isinstance(data, list): logger.info(f"[EEO] API returned list with {len(data)} items") return data except json.JSONDecodeError as e: logger.debug(f"[EEO] JSON parse failed: {e}") pass # 尝试3: 查找JSON对象(更宽松的匹配) json_obj_match = re.search(r'\{[^{}]*"(status|code)"[^{}]*"data"[^{}]*\}', content, re.DOTALL) if json_obj_match: try: data = json.loads(json_obj_match.group(0)) # 支持两种格式:status==1 或 code==200 if (data.get('status') == 1 or data.get('code') == 200) and 'data' in data: logger.info(f"[EEO] Successfully parsed with regex, found {len(data['data'])} items") return data['data'] except json.JSONDecodeError as e: logger.debug(f"[EEO] Regex parse failed: {e}") pass logger.warning(f"Failed to parse API response, content preview: {content[:200]}") return [] except Exception as e: logger.error(f"API fetch failed: {e}") return [] def _crawl_page(self, page: int) -> List[NewsItem]: """ 爬取单页新闻(使用API) Args: page: 页码 Returns: 新闻列表 """ news_items = [] try: # 使用API获取新闻列表 api_news_list = self._fetch_api_news(page=0) # 第一页 if not api_news_list: logger.warning("No news from API, fallback to HTML parsing") return self._crawl_page_html() logger.info(f"Fetched {len(api_news_list)} news from API") # 解析每条新闻 for news_data in api_news_list[:20]: # 限制20条 try: news_item = self._parse_api_news_item(news_data) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to parse news item: {e}") continue except Exception as e: logger.error(f"Error crawling page: {e}") return news_items def _parse_api_news_item(self, news_data: dict) -> Optional[NewsItem]: """ 解析API返回的新闻数据 Args: news_data: API返回的单条新闻数据 Returns: NewsItem对象 """ try: # 提取基本信息 title = news_data.get('title', '').strip() url = news_data.get('url', '') # 确保URL是完整的 if url and not url.startswith('http'): url = 'https://www.eeo.com.cn' + url if not title or not url: return None # 提取发布时间(API返回的字段可能是 published 或 publishDate) publish_time_str = news_data.get('published', '') or news_data.get('publishDate', '') publish_time = self._parse_time_string(publish_time_str) if publish_time_str else datetime.now() # 提取作者 author = news_data.get('author', '') # 获取新闻详情(内容和原始HTML) content, raw_html = self._fetch_news_content(url) if not content: return None return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author if author else None, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to parse API news item: {e}") return None def _fetch_news_content(self, url: str) -> tuple: """ 获取新闻详情页内容 Args: url: 新闻详情页URL Returns: (新闻正文, 原始HTML) """ try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) return content, raw_html except Exception as e: logger.warning(f"Failed to fetch content from {url}: {e}") return "", "" def _crawl_page_html(self) -> List[NewsItem]: """ 备用方案:直接解析HTML页面(只能获取首屏内容) """ news_items = [] try: response = self._fetch_page(self.STOCK_URL) soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links from HTML") # 限制爬取数量 max_news = 10 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error 
crawling HTML page: {e}") return news_items def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接 all_links = soup.find_all('a', href=True) # 经济观察网新闻URL模式(扩展更多模式) eeo_patterns = [ r'/\d{4}/', # 日期路径 /2024/ '.shtml', # 静态HTML '/jg/', # 经济观察 '/jinrong/', # 金融 '/zhengquan/', # 证券 '/article/', # 文章 ] for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 检查是否匹配经济观察网URL模式 is_eeo_url = False # 方式1: 检查URL模式 for pattern in eeo_patterns: if re.search(pattern, href): is_eeo_url = True break # 方式2: 检查是否包含eeo.com.cn域名 if 'eeo.com.cn' in href: is_eeo_url = True # 方式3: 检查链接的class或data属性 if not is_eeo_url: link_class = link.get('class', []) if isinstance(link_class, list): link_class_str = ' '.join(link_class) else: link_class_str = str(link_class) if any(kw in link_class_str.lower() for kw in ['news', 'article', 'item', 'title', 'list']): if href.startswith('/') or 'eeo.com.cn' in href: is_eeo_url = True if is_eeo_url and title and len(title.strip()) > 5: # 确保是完整URL if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://www.eeo.com.cn' + href elif not href.startswith('http'): href = 'https://www.eeo.com.cn/' + href.lstrip('/') # 过滤掉明显不是新闻的链接 if any(skip in href.lower() for skip in ['javascript:', 'mailto:', '#', 'void(0)', '/tag/', '/author/']): continue if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title.strip()}) logger.debug(f"EEO: Found {len(news_links)} potential news links from HTML") return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情(HTML方式)""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'article-content'}, {'class': 'content'}, {'id': 'articleContent'}, {'class': 'news-content'}, {'class': 'text_content'}, # 常见的正文类名 ] for selector in content_selectors: content_div = soup.find(['div', 'article'], selector) if content_div: # 1. 移除明确的噪音元素 for tag in content_div.find_all(['script', 'style', 'iframe', 'ins', 'select', 'input', 'button', 'form']): tag.decompose() # 2. 移除特定的广告和推荐块 for ad in content_div.find_all(class_=re.compile(r'ad|banner|share|otherContent|recommend|app-guide|qrcode', re.I)): ad.decompose() # 3. 获取所有文本,使用换行符分隔 # 关键修改:使用 get_text 而不是 find_all('p') full_text = content_div.get_text(separator='\n', strip=True) # 4. 按行分割并清洗 lines = full_text.split('\n') article_parts = [] for line in lines: line = line.strip() if not line: continue # 5. 
简单的长度过滤,防止页码等噪音 if len(line) < 2: continue article_parts.append(line) if article_parts: content = '\n'.join(article_parts) return self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: author_elem = soup.find('span', {'class': re.compile(r'author|source')}) if author_elem: return author_elem.get_text(strip=True) except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/app/tools/interactive_crawler.py ================================================ """ 交互式网页爬虫 使用 requests + BeautifulSoup 进行网页爬取 特别用于搜索结果补充,当 BochaAI 结果不足时使用 注意:主要搜索引擎(Bing、百度)都有反爬机制,本模块已做相应优化: 1. 模拟真实浏览器请求头 2. 检测验证页面并自动降级 3. 多引擎轮换备选 """ import logging import re import time import random from typing import List, Dict, Any, Optional from urllib.parse import quote_plus, urljoin, urlparse import requests from bs4 import BeautifulSoup logger = logging.getLogger(__name__) # 更完善的 User-Agent,模拟最新的 Chrome 浏览器 USER_AGENTS = [ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36', 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36', 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36', ] # 验证页面关键词(用于检测被拦截) CAPTCHA_KEYWORDS = [ '确认您是真人', '人机验证', 'captcha', 'verify you are human', '验证码', '请完成验证', '安全验证', '异常访问', '请输入验证码', '最后一步', '请解决以下难题' ] class InteractiveCrawler: """交互式网页爬虫(纯 requests 实现)""" def __init__(self, timeout: int = 15): """ 初始化爬虫 Args: timeout: 请求超时时间(秒) """ self.timeout = timeout self.session = requests.Session() self._user_agent = random.choice(USER_AGENTS) # 更完善的请求头,模拟真实浏览器 self.session.headers.update({ 'User-Agent': self._user_agent, 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7', 'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6', 'Accept-Encoding': 'gzip, deflate, br', 'Connection': 'keep-alive', 'Upgrade-Insecure-Requests': '1', 'Sec-Fetch-Dest': 'document', 'Sec-Fetch-Mode': 'navigate', 'Sec-Fetch-Site': 'none', 'Sec-Fetch-User': '?1', 'Cache-Control': 'max-age=0', 'sec-ch-ua': '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"', 'sec-ch-ua-mobile': '?0', 'sec-ch-ua-platform': '"macOS"', }) def _is_captcha_page(self, html_content: str, soup: BeautifulSoup = None) -> bool: """ 检测页面是否为验证码/人机验证页面 Args: html_content: HTML 原始内容 soup: 已解析的 BeautifulSoup 对象 Returns: True 如果是验证页面 """ text_to_check = html_content.lower() if soup: text_to_check = 
soup.get_text().lower() for keyword in CAPTCHA_KEYWORDS: if keyword.lower() in text_to_check: return True return False def search_on_bing( self, query: str, num_results: int = 10 ) -> List[Dict[str, str]]: """ 在 Bing 上搜索并获取结果 Args: query: 搜索关键词 num_results: 获取的结果数量 Returns: 搜索结果列表 [{"url": "...", "title": "...", "snippet": "..."}] """ results = [] try: # 使用国际版 Bing,中国版有更严格的反爬 search_url = f"https://www.bing.com/search?q={quote_plus(query)}&count={num_results}" logger.info(f"🔍 Bing 搜索: {query}") logger.debug(f"搜索URL: {search_url}") response = self.session.get(search_url, timeout=self.timeout) response.raise_for_status() soup = BeautifulSoup(response.text, 'html.parser') # ========== 检测验证码页面 ========== if self._is_captcha_page(response.text, soup): logger.warning("⚠️ Bing 触发人机验证,跳过此引擎") return [] # 返回空,让调用者使用其他引擎 # ========== 调试:打印找到的元素 ========== # 尝试多种选择器 b_algo_items = soup.select('.b_algo') logger.info(f"📊 Bing HTML解析: .b_algo={len(b_algo_items)}个") # 如果 .b_algo 没找到,尝试其他选择器 if not b_algo_items: # 尝试查找所有包含链接的 li 元素 li_items = soup.select('#b_results > li') logger.info(f"📊 尝试 #b_results > li: {len(li_items)}个") # 打印页面中所有链接供调试 all_links = soup.select('a[href^="http"]') logger.info(f"📊 页面总链接数: {len(all_links)}个") # 打印前10个链接 for i, link in enumerate(all_links[:15]): href = link.get('href', '') text = link.get_text(strip=True)[:50] # 过滤掉 Bing 内部链接 if 'bing.com' not in href and 'microsoft.com' not in href: logger.info(f" 链接{i+1}: {text} -> {href[:80]}") # ========== 提取搜索结果 ========== # 方法1: 标准 .b_algo 选择器 for result in b_algo_items[:num_results]: try: # 提取标题和链接 title_elem = result.select_one('h2 a') if not title_elem: title_elem = result.select_one('a') # 备选 if not title_elem: continue title = title_elem.get_text(strip=True) url = title_elem.get('href', '') # 提取摘要 snippet_elem = result.select_one('.b_caption p, p') snippet = snippet_elem.get_text(strip=True) if snippet_elem else '' if url and title and 'bing.com' not in url: results.append({ "url": url, "title": title, "snippet": snippet[:300], "source": "bing" }) logger.debug(f" ✅ 提取: {title[:40]} -> {url[:60]}") except Exception as e: logger.debug(f"解析 Bing 结果失败: {e}") continue # 方法2: 如果 .b_algo 没有结果,可能是验证页面的残留链接,不再使用备选提取 if not results and b_algo_items: logger.info("⚠️ Bing 无有效结果") logger.info(f"✅ Bing 搜索完成,获得 {len(results)} 条结果") except requests.exceptions.Timeout: logger.warning(f"⚠️ Bing 搜索超时: {query}") except requests.exceptions.RequestException as e: logger.warning(f"⚠️ Bing 搜索请求失败: {e}") except Exception as e: logger.error(f"❌ Bing 搜索失败: {e}") return results def search_on_baidu( self, query: str, num_results: int = 10 ) -> List[Dict[str, str]]: """ 在百度上搜索并获取结果(百度对简单爬虫相对友好) Args: query: 搜索关键词 num_results: 获取的结果数量 Returns: 搜索结果列表 """ results = [] try: # 百度搜索 URL search_url = f"https://www.baidu.com/s?wd={quote_plus(query)}&rn={num_results}" logger.info(f"🔍 百度搜索: {query}") logger.debug(f"搜索URL: {search_url}") # 百度需要特定的请求头 headers = { 'User-Agent': self._user_agent, 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 'Accept-Language': 'zh-CN,zh;q=0.9', 'Accept-Encoding': 'gzip, deflate', 'Referer': 'https://www.baidu.com/', 'Connection': 'keep-alive', } response = self.session.get(search_url, headers=headers, timeout=self.timeout) response.encoding = 'utf-8' response.raise_for_status() soup = BeautifulSoup(response.text, 'html.parser') # 检测验证码 if self._is_captcha_page(response.text, soup): logger.warning("⚠️ 百度触发验证,跳过此引擎") return [] # 百度搜索结果选择器(多种尝试) result_items = soup.select('.result.c-container, .c-container, 
div[class*="result"]') logger.info(f"📊 百度HTML解析: 结果容器={len(result_items)}个") for result in result_items[:num_results]: try: # 提取标题和链接 title_elem = result.select_one('h3 a, .t a, a[href]') if not title_elem: continue title = title_elem.get_text(strip=True) url = title_elem.get('href', '') # 百度使用跳转链接,需要提取真实URL # 但通常跳转链接也能用 # 提取摘要 snippet_elem = result.select_one('.c-abstract, .c-span-last, .content-right_8Zs40') snippet = snippet_elem.get_text(strip=True) if snippet_elem else '' if url and title and 'baidu.com' not in url: results.append({ "url": url, "title": title, "snippet": snippet[:300], "source": "baidu" }) logger.debug(f" ✅ 提取: {title[:40]}") except Exception as e: logger.debug(f"解析百度结果失败: {e}") continue # 备选方法:从所有标题链接提取 if not results: logger.info("⚠️ 百度标准选择器无结果,尝试提取 h3 链接...") h3_links = soup.select('h3 a') for link in h3_links[:num_results]: href = link.get('href', '') text = link.get_text(strip=True) if not href or not text or len(text) < 3: continue if href in [r['url'] for r in results]: continue results.append({ "url": href, "title": text[:100], "snippet": "", "source": "baidu" }) if len(results) >= num_results: break logger.info(f"✅ 百度搜索完成,获得 {len(results)} 条结果") except Exception as e: logger.warning(f"⚠️ 百度搜索失败: {e}") return results def search_on_baidu_news( self, query: str, num_results: int = 10 ) -> List[Dict[str, str]]: """ 在百度新闻搜索(news.baidu.com)获取新闻结果 使用 news.baidu.com 入口,返回的 URL 是真实的第三方新闻链接, 不是百度跳转链接,避免乱码问题。 Args: query: 搜索关键词 num_results: 获取的结果数量 Returns: 搜索结果列表 """ results = [] try: # 使用百度新闻入口(news.baidu.com),返回真实的第三方 URL search_url = f"https://news.baidu.com/ns?word={quote_plus(query)}&tn=news&from=news&cl=2&rn={num_results}&ct=1" logger.info(f"🔍 百度新闻搜索: {query}") logger.debug(f"搜索URL: {search_url}") # 百度需要特定的请求头 headers = { 'User-Agent': self._user_agent, 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 'Accept-Language': 'zh-CN,zh;q=0.9', 'Accept-Encoding': 'gzip, deflate', 'Referer': 'https://news.baidu.com/', 'Connection': 'keep-alive', } response = self.session.get(search_url, headers=headers, timeout=self.timeout, allow_redirects=True) response.encoding = 'utf-8' response.raise_for_status() soup = BeautifulSoup(response.text, 'html.parser') # 检测验证码 if self._is_captcha_page(response.text, soup): logger.warning("⚠️ 百度新闻触发验证,跳过") return [] # 百度新闻搜索结果选择器 # 新闻标题在 h3 > a 中,链接是真实的第三方 URL news_h3_links = soup.select('h3 a[href^="http"]') logger.info(f"📊 百度新闻HTML解析: h3链接={len(news_h3_links)}个") for link in news_h3_links[:num_results * 2]: # 多取一些,后面过滤 try: url = link.get('href', '') title = link.get_text(strip=True) # 清理标题(去掉"标题:"前缀) if title.startswith('标题:'): title = title[3:] # 过滤无效结果 if not url or not title or len(title) < 5: continue # 过滤百度内部链接(但保留百家号 baijiahao.baidu.com) if 'baidu.com' in url and 'baijiahao.baidu.com' not in url: continue if url in [r['url'] for r in results]: continue # 去重 # 尝试找到父容器获取摘要 parent = link.find_parent(['div', 'li']) snippet = '' news_source = '' publish_time = '' if parent: # 提取摘要(通常在 generic 或 p 元素中) snippet_elem = parent.select_one('[class*="summary"], [class*="abstract"], p') if snippet_elem: snippet = snippet_elem.get_text(strip=True)[:300] # 提取来源(通常在包含"来源"的链接中) source_links = parent.select('a') for src_link in source_links: src_text = src_link.get_text(strip=True) if src_text and src_text != title[:20] and len(src_text) < 20: # 可能是来源(如"同花顺财经"、"新浪财经") if '新闻来源' in (src_link.get('aria-label', '') or ''): news_source = src_text break elif not news_source and not src_text.startswith('标题'): news_source = 
src_text results.append({ "url": url, "title": title, "snippet": snippet, "source": "baidu_news", "news_source": news_source # 新闻来源(如"同花顺财经") }) logger.debug(f" ✅ 新闻: {title[:40]} | {news_source}") if len(results) >= num_results: break except Exception as e: logger.debug(f"解析百度新闻结果失败: {e}") continue logger.info(f"✅ 百度新闻搜索完成,获得 {len(results)} 条新闻") except Exception as e: logger.warning(f"⚠️ 百度新闻搜索失败: {e}") return results def search_on_sogou( self, query: str, num_results: int = 10 ) -> List[Dict[str, str]]: """ 在搜狗上搜索并获取结果(备用搜索引擎) Args: query: 搜索关键词 num_results: 获取的结果数量 Returns: 搜索结果列表 """ results = [] try: # 构建搜狗搜索 URL search_url = f"https://www.sogou.com/web?query={quote_plus(query)}" logger.info(f"🔍 搜狗搜索: {query}") logger.debug(f"搜索URL: {search_url}") response = self.session.get(search_url, timeout=self.timeout) response.raise_for_status() soup = BeautifulSoup(response.text, 'html.parser') # 检测验证码 if self._is_captcha_page(response.text, soup): logger.warning("⚠️ 搜狗触发验证,跳过此引擎") return [] # ========== 调试:打印找到的元素 ========== vrwrap_items = soup.select('.vrwrap, .rb, .results .vrwrap') logger.info(f"📊 搜狗HTML解析: .vrwrap/.rb={len(vrwrap_items)}个") # 搜狗搜索结果选择器 for result in vrwrap_items[:num_results]: try: title_elem = result.select_one('h3 a, .vr-title a, a[href]') if not title_elem: continue title = title_elem.get_text(strip=True) url = title_elem.get('href', '') snippet_elem = result.select_one('.str_info, .str-text, p, .txt-info') snippet = snippet_elem.get_text(strip=True) if snippet_elem else '' if url and title and 'sogou.com' not in url: results.append({ "url": url, "title": title, "snippet": snippet[:300], "source": "sogou" }) logger.debug(f" ✅ 提取: {title[:40]} -> {url[:60]}") except Exception as e: logger.debug(f"解析搜狗结果失败: {e}") continue # 备选方法:从页面链接提取 if not results: logger.info("⚠️ 搜狗标准选择器无结果,尝试从页面链接提取...") all_links = soup.select('a[href^="http"]') for link in all_links[:num_results * 3]: href = link.get('href', '') text = link.get_text(strip=True) if not href or not text or len(text) < 5: continue if 'sogou.com' in href: continue if href in [r['url'] for r in results]: continue results.append({ "url": href, "title": text[:100], "snippet": "", "source": "sogou" }) if len(results) >= num_results: break logger.info(f"✅ 搜狗搜索完成,获得 {len(results)} 条结果") except Exception as e: logger.warning(f"⚠️ 搜狗搜索失败: {e}") return results def search_on_360( self, query: str, num_results: int = 10 ) -> List[Dict[str, str]]: """ 在 360 搜索上搜索并获取结果 Args: query: 搜索关键词 num_results: 获取的结果数量 Returns: 搜索结果列表 """ results = [] try: # 构建 360 搜索 URL search_url = f"https://www.so.com/s?q={quote_plus(query)}" logger.info(f"🔍 360搜索: {query}") logger.debug(f"搜索URL: {search_url}") response = self.session.get(search_url, timeout=self.timeout) response.raise_for_status() soup = BeautifulSoup(response.text, 'html.parser') # 检测验证码 if self._is_captcha_page(response.text, soup): logger.warning("⚠️ 360触发验证,跳过此引擎") return [] # ========== 调试:打印找到的元素 ========== res_items = soup.select('.res-list, .result, li.res-list') logger.info(f"📊 360 HTML解析: .res-list/.result={len(res_items)}个") # 360 搜索结果选择器 for result in res_items[:num_results]: try: title_elem = result.select_one('h3 a, .res-title a, a[href]') if not title_elem: continue title = title_elem.get_text(strip=True) url = title_elem.get('href', '') snippet_elem = result.select_one('.res-desc, p.res-summary, p, .res-comm-con') snippet = snippet_elem.get_text(strip=True) if snippet_elem else '' if url and title and 'so.com' not in url and '360.cn' not in url: results.append({ 
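# Normalised to the same result shape as the other engines (url/title/snippet/source),
# plus "news_source" (the publishing outlet picked up from the hit's parent container,
# e.g. "同花顺财经"), which only the Baidu News entry point provides.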
"url": url, "title": title, "snippet": snippet[:300], "source": "360" }) logger.debug(f" ✅ 提取: {title[:40]} -> {url[:60]}") except Exception as e: logger.debug(f"解析 360 结果失败: {e}") continue # 备选方法:从页面链接提取 if not results: logger.info("⚠️ 360 标准选择器无结果,尝试从页面链接提取...") all_links = soup.select('a[href^="http"]') for link in all_links[:num_results * 3]: href = link.get('href', '') text = link.get_text(strip=True) if not href or not text or len(text) < 5: continue if 'so.com' in href or '360.cn' in href: continue if href in [r['url'] for r in results]: continue results.append({ "url": href, "title": text[:100], "snippet": "", "source": "360" }) if len(results) >= num_results: break logger.info(f"✅ 360搜索完成,获得 {len(results)} 条结果") except Exception as e: logger.warning(f"⚠️ 360搜索失败: {e}") return results def interactive_search( self, query: str, engines: List[str] = None, num_results: int = 10, search_type: str = "news", # 新增参数:news(新闻)或 web(网页) **kwargs # 兼容旧接口 ) -> List[Dict[str, str]]: """ 使用多个搜索引擎进行搜索 Args: query: 搜索关键词 engines: 搜索引擎列表 ['baidu_news', 'baidu', 'sogou', '360', 'bing'] num_results: 每个引擎的结果数量 search_type: 搜索类型 'news'(新闻优先)或 'web'(网页) Returns: 合并的搜索结果 """ if engines is None: if search_type == "news": # 新闻搜索:优先使用百度资讯 engines = ["baidu_news", "sogou"] else: # 普通网页搜索 engines = ["baidu", "sogou"] all_results = [] engines_tried = [] for engine in engines: try: engine_lower = engine.lower() if engine_lower == "baidu_news": results = self.search_on_baidu_news(query, num_results) elif engine_lower == "baidu": results = self.search_on_baidu(query, num_results) elif engine_lower == "bing": results = self.search_on_bing(query, num_results) elif engine_lower == "sogou": results = self.search_on_sogou(query, num_results) elif engine_lower == "360": results = self.search_on_360(query, num_results) else: logger.warning(f"⚠️ 不支持的搜索引擎: {engine}") continue if results: all_results.extend(results) engines_tried.append(engine_lower) logger.info(f"✅ {engine} 返回 {len(results)} 条结果") else: logger.info(f"⚠️ {engine} 无结果或被拦截") # 搜索间隔,避免被封 if len(engines) > 1: time.sleep(random.uniform(0.8, 1.5)) except Exception as e: logger.error(f"❌ 使用 {engine} 搜索失败: {e}") continue # 如果所有引擎都失败了,尝试备用引擎 if not all_results: backup_engines = ["baidu_news", "360", "baidu", "sogou"] for backup in backup_engines: if backup not in [e.lower() for e in engines]: logger.info(f"🔄 尝试备用引擎: {backup}") try: if backup == "baidu_news": results = self.search_on_baidu_news(query, num_results) elif backup == "360": results = self.search_on_360(query, num_results) elif backup == "baidu": results = self.search_on_baidu(query, num_results) elif backup == "sogou": results = self.search_on_sogou(query, num_results) if results: all_results.extend(results) engines_tried.append(backup) logger.info(f"✅ 备用引擎 {backup} 返回 {len(results)} 条结果") break except Exception as e: logger.warning(f"备用引擎 {backup} 也失败: {e}") continue # 去重 seen_urls = set() unique_results = [] for r in all_results: if r["url"] not in seen_urls: seen_urls.add(r["url"]) unique_results.append(r) logger.info(f"交互式搜索完成: {len(all_results)} -> {len(unique_results)} (去重后), 使用引擎: {engines_tried}") return unique_results def crawl_page(self, url: str) -> Optional[Dict[str, Any]]: """ 爬取单个页面内容 Args: url: 页面 URL Returns: {"url": "...", "title": "...", "content": "...", "text": "...", "html": "..."} 或 None """ try: response = self.session.get(url, timeout=self.timeout) response.encoding = response.apparent_encoding or 'utf-8' # 保存原始 HTML(清理 NUL 字符) raw_html = response.text.replace('\x00', 
'').replace('\0', '') soup = BeautifulSoup(raw_html, 'html.parser') # 获取标题(在移除元素之前) title = '' title_elem = soup.find('title') if title_elem: title = title_elem.get_text(strip=True) # 尝试获取 h1 作为更好的标题 h1_elem = soup.find('h1') if h1_elem: h1_text = h1_elem.get_text(strip=True) if h1_text and len(h1_text) > 5: title = h1_text # 移除无关元素(用于提取正文) for elem in soup.select('script, style, iframe, nav, footer, header, aside, .ad, .advertisement, .comment, .sidebar'): elem.decompose() # 获取主要内容 # 优先选择 article, main, .content 等 main_content = None content_selectors = [ 'article', 'main', '.content', '.post-content', '.article-content', '#content', '.main-content', '.news-content', '.article-body', '.entry-content', '.post-body', '[itemprop="articleBody"]' ] for selector in content_selectors: main_content = soup.select_one(selector) if main_content: break if not main_content: main_content = soup.find('body') or soup # 提取文本 text_content = main_content.get_text(separator='\n', strip=True) # 清理文本 text_content = re.sub(r'\n{3,}', '\n\n', text_content) # 不再截断内容,保留完整正文(数据库字段应该支持长文本) # text_content = text_content[:5000] # 移除截断 logger.debug(f"📄 爬取完成: {title[:40]}... | 正文{len(text_content)}字符 | HTML{len(raw_html) if raw_html else 0}字符") return { "url": url, "title": title, "content": text_content, # 完整正文 "text": text_content, # 兼容字段 "html": raw_html if raw_html else None # 完整原始 HTML } except requests.exceptions.Timeout: logger.warning(f"⚠️ 爬取页面超时: {url[:60]}...") except Exception as e: logger.warning(f"⚠️ 爬取页面失败 {url[:60]}...: {e}") return None def crawl_search_results( self, search_results: List[Dict[str, str]], max_results: int = 5 ) -> List[Dict[str, Any]]: """ 爬取搜索结果中的页面内容 Args: search_results: 搜索结果列表 max_results: 最多爬取多少个页面 Returns: 爬取结果列表 [{"url": "...", "title": "...", "content": "..."}] """ crawled = [] for i, result in enumerate(search_results[:max_results]): url = result.get("url") if not url: continue logger.info(f"📄 爬取页面 {i+1}/{min(max_results, len(search_results))}: {url[:60]}...") page_data = self.crawl_page(url) if page_data and page_data.get("content"): page_data["snippet"] = result.get("snippet", "") page_data["source"] = result.get("source", "web") crawled.append(page_data) logger.debug(f"✅ 爬取成功: {page_data['title'][:50]}...") else: # 爬取失败时,使用搜索结果的摘要 crawled.append({ "url": url, "title": result.get("title", ""), "content": result.get("snippet", ""), "snippet": result.get("snippet", ""), "source": result.get("source", "web") }) logger.debug(f"⚠️ 使用摘要代替: {result.get('title', 'N/A')[:50]}...") # 爬取间隔 if i < max_results - 1: time.sleep(random.uniform(0.3, 0.8)) logger.info(f"📄 页面爬取完成: {len(crawled)} 个成功") return crawled # 便捷函数 def create_interactive_crawler(headless: bool = True, **kwargs) -> InteractiveCrawler: """创建交互式爬虫(兼容旧接口)""" return InteractiveCrawler() def search_and_crawl( query: str, engines: List[str] = None, max_search_results: int = 10, max_crawl_results: int = 5, **kwargs # 兼容旧接口 ) -> Dict[str, Any]: """ 一体化搜索和爬取函数 Args: query: 搜索关键词 engines: 搜索引擎列表 max_search_results: 最多获取多少个搜索结果 max_crawl_results: 最多爬取多少个页面 Returns: { "search_results": [...], "crawled_results": [...], "total_results": int } """ crawler = InteractiveCrawler() logger.info(f"🔍 开始搜索: {query}") search_results = crawler.interactive_search( query, engines=engines, num_results=max_search_results ) if not search_results: logger.warning(f"搜索未返回结果: {query}") return { "search_results": [], "crawled_results": [], "total_results": 0 } logger.info(f"📄 开始爬取前 {max_crawl_results} 个结果") crawled_results = crawler.crawl_search_results( 
search_results, max_results=max_crawl_results ) return { "search_results": search_results, "crawled_results": crawled_results, "total_results": len(crawled_results) } ================================================ FILE: backend/app/tools/jingji21_crawler.py ================================================ """ 21经济网爬虫工具 目标URL: https://www.21jingji.com/ (证券栏目) """ import re import logging from typing import List, Optional from datetime import datetime, timedelta from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class Jingji21CrawlerTool(BaseCrawler): """ 21经济网爬虫 主要爬取证券栏目 """ BASE_URL = "https://www.21jingji.com/" # 证券栏目URL STOCK_URL = "https://www.21jingji.com/channel/capital/" SOURCE_NAME = "jingji21" def __init__(self): super().__init__( name="jingji21_crawler", description="Crawl financial news from 21 Jingji (21jingji.com)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取21经济网新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled Jingji21, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling Jingji21: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: """爬取单页新闻""" news_items = [] try: # 尝试爬取证券栏目或主页 try: response = self._fetch_page(self.STOCK_URL) except: response = self._fetch_page(self.BASE_URL) soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links") # 限制爬取数量 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error crawling page: {e}") return news_items def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接 all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 21经济网新闻URL模式 if ('/article/' in href or '/html/' in href or '.shtml' in href) and title: # 确保是完整URL if not href.startswith('http'): href = 'https://www.21jingji.com' + href if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) # 确保编码正确:21经济网可能使用 gbk 编码 if '21jingji.com' in url: # 尝试多种编码 encodings = ['utf-8', 'gbk', 'gb2312', 'gb18030'] raw_html = None for enc in encodings: try: raw_html = response.content.decode(enc) # 验证是否包含中文字符(避免乱码) if '\u4e00' <= raw_html[0:100] <= '\u9fff' or any('\u4e00' <= c <= '\u9fff' for c in raw_html[:500]): break except (UnicodeDecodeError, LookupError): continue if raw_html is None: raw_html = response.text else: raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except 
Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'article-content'}, {'class': 'content'}, {'class': 'text'}, {'id': 'content'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content: return self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: author_elem = soup.find('span', {'class': re.compile(r'author|source')}) if author_elem: return author_elem.get_text(strip=True) except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/app/tools/jwview_crawler.py ================================================ """ 中新经纬爬虫工具 目标URL: https://www.jwview.com/ """ import re import logging from typing import List, Optional from datetime import datetime, timedelta from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class JwviewCrawlerTool(BaseCrawler): """ 中新经纬新闻爬虫 爬取中新经纬财经新闻 """ BASE_URL = "https://www.jwview.com/" # 股票/证券专栏URL(如果有) STOCK_URL = "https://www.jwview.com/jingwei/html/index.shtml" SOURCE_NAME = "jwview" def __init__(self): super().__init__( name="jwview_crawler", description="Crawl financial news from Zhongxin Jingwei (jwview.com)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取中新经纬新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled Jwview, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling Jwview: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: """爬取单页新闻""" news_items = [] try: # 尝试爬取主页或股票专栏 response = self._fetch_page(self.BASE_URL) # 金融界可能使用 gbk 编码 if response.encoding == 'ISO-8859-1' or not response.encoding: try: response.content.decode('gbk') response.encoding = 'gbk' except: response.encoding = 'utf-8' soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links") # 限制爬取数量 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error crawling page: {e}") return news_items def 
_extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接(中新经纬的URL模式) all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 中新经纬新闻URL模式 if ('/jingwei/' in href or '/html/' in href) and title: # 规范化 URL,避免出现 //www... 重复前缀 if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://www.jwview.com' + href elif not href.startswith('http'): href = 'https://www.jwview.com/' + href.lstrip('/') if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'content'}, {'class': 'article-content'}, {'class': 'text'}, {'id': 'content'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content: return self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 处理相对时间 if '分钟前' in time_str: minutes = int(re.search(r'(\d+)', time_str).group(1)) return now - timedelta(minutes=minutes) elif '小时前' in time_str: hours = int(re.search(r'(\d+)', time_str).group(1)) return now - timedelta(hours=hours) elif '昨天' in time_str: return now - timedelta(days=1) # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: author_elem = soup.find('span', {'class': re.compile(r'author|source')}) if author_elem: return author_elem.get_text(strip=True) except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/app/tools/nbd_crawler.py ================================================ """ 每日经济新闻爬虫工具 目标URL: https://finance.nbd.com.cn/ """ import re import logging from typing import List, Optional from datetime import datetime from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class NbdCrawlerTool(BaseCrawler): """ 每日经济新闻爬虫 主要爬取财经股市新闻 
""" BASE_URL = "https://www.nbd.com.cn/" STOCK_URL = "https://www.nbd.com.cn/columns/3/" SOURCE_NAME = "nbd" def __init__(self): super().__init__( name="nbd_crawler", description="Crawl financial news from NBD (nbd.com.cn)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取每日经济新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled NBD, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling NBD: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: """爬取单页新闻""" news_items = [] try: response = self._fetch_page(self.STOCK_URL) soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links") # 限制爬取数量 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: # 如果是503错误,记录但继续处理其他URL error_str = str(e) if '503' in error_str or 'Service Temporarily Unavailable' in error_str: logger.warning(f"Skipping {link_info.get('url', 'unknown')} due to 503 error (server overloaded)") else: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error crawling page: {e}") return news_items def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接 all_links = soup.find_all('a', href=True) # NBD新闻URL模式(扩展更多模式) nbd_patterns = [ '/articles/', # 文章列表 '/article/', # 文章 '.html', # HTML页面 '/columns/', # 栏目 '/finance/', # 财经 ] for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 检查是否匹配NBD URL模式 is_nbd_url = False # 方式1: 检查URL模式 for pattern in nbd_patterns: if pattern in href: is_nbd_url = True break # 方式2: 检查是否包含nbd.com.cn域名 if 'nbd.com.cn' in href: is_nbd_url = True # 方式3: 检查链接的class或data属性 if not is_nbd_url: link_class = link.get('class', []) if isinstance(link_class, list): link_class_str = ' '.join(link_class) else: link_class_str = str(link_class) if any(kw in link_class_str.lower() for kw in ['news', 'article', 'item', 'title', 'list']): if href.startswith('/') or 'nbd.com.cn' in href: is_nbd_url = True if is_nbd_url and title and len(title.strip()) > 5: # 确保是完整URL if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://www.nbd.com.cn' + href elif not href.startswith('http'): href = 'https://www.nbd.com.cn/' + href.lstrip('/') # 过滤掉明显不是新闻的链接 if any(skip in href.lower() for skip in ['javascript:', 'mailto:', '#', 'void(0)', '/tag/', '/author/', '/user/', '/login']): continue if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title.strip()}) logger.debug(f"NBD: Found {len(news_links)} potential news links") return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, 
raw_html=raw_html, # 保存原始 HTML ) except Exception as e: # 检查是否是503错误(服务器过载) error_str = str(e) if '503' in error_str or 'Service Temporarily Unavailable' in error_str: logger.debug(f"Skipping {url} due to 503 error (server overloaded, will retry later)") # 对于503错误,直接返回None,不记录为警告,因为这是临时性问题 return None else: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" # 每经网站可能的正文容器选择器(按优先级排序) content_selectors = [ # 新版页面结构 {'class': 'article-body'}, {'class': 'article__body'}, {'class': 'article-text'}, {'class': 'content-article'}, {'class': 'main-content'}, # 旧版页面结构 {'class': 'g-article-content'}, {'class': 'article-content'}, {'class': 'content'}, {'id': 'contentText'}, {'id': 'article-content'}, # 通用选择器 {'itemprop': 'articleBody'}, ] for selector in content_selectors: content_div = soup.find(['div', 'article', 'section'], selector) if content_div: # 移除脚本、样式、广告等无关元素 for tag in content_div.find_all(['script', 'style', 'iframe', 'ins', 'noscript']): tag.decompose() for ad in content_div.find_all(class_=re.compile(r'ad|advertisement|banner|recommend')): ad.decompose() # 提取所有段落,不限制数量 paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content and len(content) > 50: return self._clean_text(content) # 如果没有 p 标签,直接取文本 text = content_div.get_text(separator='\n', strip=True) if text and len(text) > 50: return self._clean_text(text) # 后备方案:取所有段落(不限制数量) paragraphs = soup.find_all('p') if paragraphs: # 过滤掉可能的导航、页脚等短段落 valid_paragraphs = [ p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True) and len(p.get_text(strip=True)) > 10 ] content = '\n'.join(valid_paragraphs) if content: return self._clean_text(content) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date|pub')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: author_elem = soup.find('span', {'class': re.compile(r'author|source|editor')}) if author_elem: return author_elem.get_text(strip=True) except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/app/tools/netease163_crawler.py ================================================ """ 网易财经爬虫工具 目标URL: https://money.163.com/ """ import re import logging from typing import List, Optional from datetime import datetime from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class Netease163CrawlerTool(BaseCrawler): """ 网易财经爬虫 主要爬取财经股市新闻 """ BASE_URL = "https://money.163.com/" STOCK_URL = "https://money.163.com/stock/" SOURCE_NAME = "163" def __init__(self): super().__init__( name="netease163_crawler", description="Crawl financial news from Netease Money (money.163.com)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> 
List[NewsItem]: """ 爬取网易财经新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled 163, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling 163: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: """爬取单页新闻""" news_items = [] try: # 尝试爬取股票栏目或主页 try: response = self._fetch_page(self.STOCK_URL) except: response = self._fetch_page(self.BASE_URL) soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links") # 限制爬取数量 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error crawling page: {e}") return news_items def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接 all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 网易新闻URL模式 if ('money.163.com' in href or 'stock' in href) and title: # 确保是完整URL if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://money.163.com' + href elif not href.startswith('http'): href = 'https://money.163.com/' + href.lstrip('/') if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'post_text'}, {'id': 'endText'}, {'class': 'article-content'}, {'class': 'content'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content: return self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('div', {'class': re.compile(r'post_time|time')}) if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def 
_extract_author(self, soup: BeautifulSoup) -> Optional[str]:
        """提取作者"""
        try:
            author_elem = soup.find('span', {'class': re.compile(r'author|source')})
            if not author_elem:
                author_elem = soup.find('div', {'id': 'ne_article_source'})
            if author_elem:
                return author_elem.get_text(strip=True)
        except Exception as e:
            logger.debug(f"Failed to extract author: {e}")
        return None


================================================
FILE: backend/app/tools/search_engine_crawler.py
================================================
"""
搜索引擎爬虫工具
直接爬取搜索引擎结果页面(Bing/Baidu)
"""
import logging
import re
import requests
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from urllib.parse import quote_plus
from bs4 import BeautifulSoup
import time

logger = logging.getLogger(__name__)


class SearchEngineCrawler:
    """
    搜索引擎爬虫
    直接爬取 Bing/Baidu 搜索结果
    """

    def __init__(self, mcp_server_path: Optional[str] = None):
        """初始化搜索引擎爬虫

        Args:
            mcp_server_path: 兼容旧接口的占位参数,当前实现未使用
        """
        # 搜索引擎结果页 URL 模板
        self.search_engines = {
            "bing": "https://www.bing.com/search?q={query}",
            "baidu": "https://www.baidu.com/s?wd={query}",
        }
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8',
            'Accept-Encoding': 'gzip, deflate',
            'DNT': '1',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1'
        }
        self.session = requests.Session()
        self.session.headers.update(self.headers)
        logger.info("🔧 搜索引擎爬虫已初始化")

    def _fetch_url(self, url: str, timeout: int = 10) -> Optional[str]:
        """
        爬取URL内容

        Args:
            url: 目标URL
            timeout: 超时时间

        Returns:
            HTML内容
        """
        try:
            response = self.session.get(url, timeout=timeout)
            response.raise_for_status()

            # 尝试检测编码
            if response.encoding == 'ISO-8859-1':
                # 对于中文网站,尝试使用 gb2312 或 utf-8
                encodings = ['utf-8', 'gb2312', 'gbk']
                for enc in encodings:
                    try:
                        response.encoding = enc
                        _ = response.text
                        break
                    except Exception:
                        continue

            return response.text
        except Exception as e:
            logger.error(f"❌ 爬取失败 {url}: {e}")
            return None

    def search_with_engine(
        self,
        query: str,
        engine: str = "bing",
        days: int = 30,
        max_results: int = 50
    ) -> List[Dict[str, Any]]:
        """
        使用搜索引擎搜索新闻

        Args:
            query: 搜索关键词
            engine: 搜索引擎 (bing/baidu)
            days: 时间范围(天)
            max_results: 最大结果数

        Returns:
            新闻列表
        """
        if engine not in self.search_engines:
            logger.error(f"❌ 不支持的搜索引擎: {engine}")
            return []

        # 构建搜索URL
        search_query = self._build_search_query(query, days)
        search_url = self.search_engines[engine].format(query=quote_plus(search_query))

        logger.info(f"🔍 搜索引擎爬取: {engine} - {search_query}")
        logger.info(f"   URL: {search_url}")

        # 抓取搜索结果页 HTML
        html = self._fetch_url(search_url)
        if not html:
            logger.warning(f"⚠️ 搜索引擎爬取失败: {search_url}")
            return []

        # 将结果页中的超链接整理为 [标题](URL) 形式的文本,供 _parse_search_results 解析
        soup = BeautifulSoup(html, 'html.parser')
        link_lines = []
        for a in soup.find_all('a', href=True):
            text = a.get_text(strip=True)
            href = a['href']
            if text and href.startswith('http'):
                link_lines.append(f"[{text}]({href})")
        content = '\n'.join(link_lines)

        # 解析搜索结果
        news_items = self._parse_search_results(
            content=content,
            engine=engine,
            max_results=max_results
        )

        logger.info(f"✅ 从 {engine} 提取到 {len(news_items)} 条结果")
        return news_items

    def _build_search_query(self, query: str, days: int) -> str:
        """
        构建搜索查询字符串(添加时间限制)

        Args:
            query: 原始查询
            days: 时间范围

        Returns:
            增强的搜索查询
        """
        # 添加时间范围(对于 Bing 和 Baidu)
        # Bing: 支持 "query site:xxx.com"
        # 可以添加新闻源限制

        # 可选:限制到新闻网站
        news_sites = [
            "sina.com.cn",
            "163.com",
            "eastmoney.com",
            "cnstock.com",
            "stcn.com",
            "caijing.com.cn",
            "yicai.com",
        ]

        # 构建基础查询
        enhanced_query = f"{query} 新闻"

        # 添加时间提示词
        if days <= 7:
            enhanced_query += " 最近一周"
        elif days <= 30:
            enhanced_query += " 最近一个月"

        return enhanced_query

    def _parse_search_results(
        self,
        content: str,
        engine: str,
        max_results: int
    ) -> List[Dict[str, Any]]:
        """
        解析搜索引擎返回的内容,提取新闻链接和标题

        Args:
            content: 从结果页提取的链接列表(Markdown格式)
            engine: 搜索引擎类型
            max_results: 最大结果数

        Returns:
新闻条目列表 """ news_items = [] # 从 Markdown 内容中提取链接 # 格式:[标题](URL) link_pattern = r'\[([^\]]+)\]\(([^\)]+)\)' matches = re.findall(link_pattern, content) for title, url in matches[:max_results]: # 过滤掉搜索引擎自身的链接 if engine in url.lower(): continue # 过滤掉非新闻链接 if not self._is_news_url(url): continue news_items.append({ "title": title.strip(), "url": url.strip(), "snippet": "", # 暂时为空,后续可以从 content 中提取 "source": self._extract_source_from_url(url), "engine": engine }) return news_items def _is_news_url(self, url: str) -> bool: """判断是否为新闻URL""" news_domains = [ "sina.com", "163.com", "eastmoney.com", "cnstock.com", "stcn.com", "caijing.com", "yicai.com", "nbd.com", "jwview.com", "eeo.com.cn", "finance.qq.com" ] return any(domain in url.lower() for domain in news_domains) def _extract_source_from_url(self, url: str) -> str: """从URL提取来源""" domain_mapping = { "sina.com": "新浪财经", "163.com": "网易财经", "eastmoney.com": "东方财富", "cnstock.com": "中国证券网", "stcn.com": "证券时报", "caijing.com": "财经网", "yicai.com": "第一财经", "nbd.com": "每日经济新闻", "jwview.com": "金融界", "eeo.com.cn": "经济观察网", "qq.com": "腾讯财经", } for domain, source in domain_mapping.items(): if domain in url.lower(): return source return "未知来源" def search_stock_news( self, stock_name: str, stock_code: str, days: int = 30, engines: Optional[List[str]] = None, max_per_engine: int = 30 ) -> List[Dict[str, Any]]: """ 搜索股票新闻(多搜索引擎) Args: stock_name: 股票名称 stock_code: 股票代码 days: 时间范围 engines: 搜索引擎列表,默认 ["bing"] max_per_engine: 每个搜索引擎最大结果数 Returns: 新闻列表 """ if engines is None: engines = ["bing"] # 默认只用 Bing(Baidu 可能需要处理反爬) all_news = [] # 构建搜索关键词 queries = [ stock_name, f"{stock_name} {stock_code}", f"{stock_name} 公告", ] for engine in engines: for query in queries: try: news = self.search_with_engine( query=query, engine=engine, days=days, max_results=max_per_engine ) all_news.extend(news) except Exception as e: logger.error(f"❌ 搜索失败 [{engine}] {query}: {e}") # 去重(按URL) seen_urls = set() unique_news = [] for news in all_news: url = news.get("url") if url and url not in seen_urls: seen_urls.add(url) unique_news.append(news) logger.info(f"✅ 多引擎搜索完成: 总计 {len(unique_news)} 条(去重后)") return unique_news # 便捷函数 def create_search_engine_crawler(mcp_server_path: Optional[str] = None) -> SearchEngineCrawler: """创建搜索引擎爬虫实例""" return SearchEngineCrawler(mcp_server_path) # 测试代码 if __name__ == "__main__": logging.basicConfig(level=logging.INFO) crawler = create_search_engine_crawler() # 测试搜索 results = crawler.search_stock_news( stock_name="深振业A", stock_code="000006", days=7, engines=["bing"], max_per_engine=10 ) print(f"\n✅ 搜索到 {len(results)} 条新闻:") for i, news in enumerate(results[:5], 1): print(f"{i}. 
{news['title']}") print(f" 来源: {news['source']}") print(f" URL: {news['url']}") ================================================ FILE: backend/app/tools/sina_crawler.py ================================================ """ 新浪财经爬虫工具 重构自 legacy_v1/Crawler/crawler_sina.py """ import re import logging from typing import List, Optional from datetime import datetime from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class SinaCrawlerTool(BaseCrawler): """ 新浪财经新闻爬虫 爬取最新滚动新闻页面 """ # 新浪财经最新滚动新闻页面(2024年后的新URL) BASE_URL = "https://finance.sina.com.cn/roll/c/56592.shtml" # 暂不支持翻页,只爬首页 SOURCE_NAME = "sina" def __init__(self): super().__init__( name="sina_finance_crawler", description="Crawl financial news from Sina Finance (sina.com.cn)" ) self.min_chinese_ratio = 0.5 # 最小中文比例阈值 def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取新浪财经新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] for page in range(start_page, end_page + 1): try: page_news = self._crawl_page(page) news_list.extend(page_news) logger.info(f"Crawled page {page}, got {len(page_news)} news items") except Exception as e: logger.error(f"Failed to crawl page {page}: {e}") continue return news_list def _crawl_page(self, page: int) -> List[NewsItem]: """ 爬取单页新闻列表 Args: page: 页码(目前只支持首页,忽略此参数) Returns: 新闻列表 """ url = self.BASE_URL # 新URL不支持翻页,只爬首页 logger.info(f"Fetching page: {url}") response = self._fetch_page(url) # 设置正确的编码 response.encoding = 'utf-8' soup = self._parse_html(response.text) # 查找新闻链接(改进选择器,更精确地找到新闻链接) news_links = [] for link in soup.find_all('a', href=True): href = link.get('href', '') # 匹配新浪财经股票相关新闻URL if 'finance.sina.com.cn' in href and ('/stock/' in href or '/roll/' in href): # 确保是完整的URL if href.startswith('http'): news_links.append(href) elif href.startswith('//'): news_links.append('http:' + href) # 去重 news_links = list(set(news_links)) logger.info(f"Found {len(news_links)} news links on page {page}") # 爬取每条新闻详情(限制每页最多50条,避免超时) news_list = [] max_news_per_page = 50 if page == 1 else 30 # 第一页爬取更多,其他页少一些 for idx, news_url in enumerate(news_links[:max_news_per_page], 1): try: logger.debug(f"Crawling news {idx}/{min(len(news_links), max_news_per_page)}: {news_url}") news_item = self._crawl_news_detail(news_url) if news_item: news_list.append(news_item) logger.debug(f"Successfully crawled: {news_item.title[:50]}") except Exception as e: logger.warning(f"Failed to crawl news detail {news_url}: {e}") continue logger.info(f"Successfully crawled {len(news_list)} news items from page {page}") return news_list def _crawl_news_detail(self, url: str) -> Optional[NewsItem]: """ 爬取新闻详情页 Args: url: 新闻URL Returns: 新闻项或None """ try: response = self._fetch_page(url) response.encoding = BeautifulSoup(response.content, "lxml").original_encoding raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取标题 title = self._extract_title(soup) if not title: return None # 提取摘要和关键词 summary, keywords = self._extract_meta(soup) # 提取发布时间 publish_time = self._extract_date(soup) # 提取关联股票代码 stock_codes = self._extract_stock_codes(soup) # 提取正文 content = self._extract_content(soup) if not content or len(content) < 50: return None return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, summary=summary, keywords=keywords, stock_codes=stock_codes, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.error(f"Error crawling {url}: {e}") return None def 
_extract_title(self, soup: BeautifulSoup) -> Optional[str]: """提取标题""" # 尝试多个可能的标题位置 title_tag = soup.find('h1', class_='main-title') if not title_tag: title_tag = soup.find('h1') if not title_tag: title_tag = soup.find('title') if title_tag: title = title_tag.get_text().strip() # 移除来源信息 title = re.sub(r'[-_].*?(新浪|财经|网)', '', title) return title.strip() return None def _extract_meta(self, soup: BeautifulSoup) -> tuple: """提取元数据(摘要和关键词)""" summary = "" keywords = [] for meta in soup.find_all('meta'): name = meta.get('name', '').lower() content = meta.get('content', '') if name == 'description': summary = content elif name == 'keywords': keywords = [kw.strip() for kw in content.split(',') if kw.strip()] return summary, keywords def _extract_date(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" # 查找时间标签 for span in soup.find_all('span'): # 检查 class 属性 class_attr = span.get('class', []) if 'date' in class_attr or 'time-source' in class_attr: date_text = span.get_text() return self._parse_date(date_text) # 检查 id 属性 if span.get('id') == 'pub_date': date_text = span.get_text() return self._parse_date(date_text) return None def _parse_date(self, date_text: str) -> Optional[datetime]: """解析日期字符串""" try: # 格式:2024年12月01日 10:30 date_text = date_text.strip() date_text = date_text.replace('年', '-').replace('月', '-').replace('日', '') # 尝试多种格式 for fmt in [ '%Y-%m-%d %H:%M', '%Y-%m-%d %H:%M:%S', '%Y-%m-%d', ]: try: return datetime.strptime(date_text.strip(), fmt) except ValueError: continue except Exception: pass return None def _extract_stock_codes(self, soup: BeautifulSoup) -> List[str]: """提取关联股票代码""" stock_codes = [] for span in soup.find_all('span'): span_id = span.get('id', '') if span_id.startswith('stock_'): # 格式:stock_sh600519 code = span_id[6:] # 移除 'stock_' 前缀 if code: stock_codes.append(code.upper()) return list(set(stock_codes)) def _extract_content(self, soup: BeautifulSoup) -> str: """提取正文内容""" # 尝试使用更精确的选择器 content_selectors = [ {'id': 'artibody'}, {'class': 'article-content'}, {'class': 'article'}, {'id': 'article'}, ] for selector in content_selectors: content_div = soup.find(['div', 'article'], selector) if content_div: # 1. 移除明确的噪音元素 for tag in content_div.find_all(['script', 'style', 'iframe', 'ins', 'select', 'input', 'button', 'form']): tag.decompose() # 2. 移除特定的广告和推荐块 for ad in content_div.find_all(class_=re.compile(r'ad|banner|share|otherContent|recommend|app-guide', re.I)): ad.decompose() # 3. 获取所有文本,使用换行符分隔 # 关键修改:使用 get_text 而不是 find_all('p'),确保不漏掉裸露的文本节点 full_text = content_div.get_text(separator='\n', strip=True) # 4. 按行分割并清洗 lines = full_text.split('\n') article_parts = [] for line in lines: line = line.strip() if not line: continue # 5. 过滤和清洗行 # 检查中文比例 chinese_ratio = self._extract_chinese_ratio(line) # 宽松的保留策略: # - 忽略极短的非中文行(可能是页码、特殊符号) if len(line) < 2: continue # 保留条件: # 1. 包含一定比例中文(>5%) # 2. 
或者长文本(>20字符),可能是纯数据或英文段落 if chinese_ratio > 0.05 or len(line) > 20: clean_line = self._clean_text(line) if clean_line and not self._is_noise_text(clean_line): article_parts.append(clean_line) if article_parts: return '\n'.join(article_parts) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) def _is_noise_text(self, text: str) -> bool: """判断是否为噪音文本(广告、版权等)""" noise_patterns = [ r'^责任编辑', r'^编辑[::]', r'^来源[::]', r'^声明[::]', r'^免责声明', r'^版权', r'^copyright', r'^点击进入', r'^相关阅读', r'^延伸阅读', r'^\s*$', r'登录新浪财经APP', r'搜索【信披】', r'缩小字体', r'放大字体', r'收藏', r'微博', r'微信', r'分享', r'腾讯QQ', ] text_lower = text.lower().strip() for pattern in noise_patterns: if re.match(pattern, text_lower, re.I) or re.search(pattern, text_lower, re.I): return True return False # 便捷创建函数 def create_sina_crawler() -> SinaCrawlerTool: """创建新浪财经爬虫实例""" return SinaCrawlerTool() ================================================ FILE: backend/app/tools/tencent_crawler.py ================================================ """ 腾讯财经爬虫工具 目标URL: https://news.qq.com/ch/finance/ """ import re import logging from typing import List, Optional from datetime import datetime, timedelta from bs4 import BeautifulSoup import json from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class TencentCrawlerTool(BaseCrawler): """ 腾讯财经新闻爬虫 爬取腾讯财经频道最新新闻 """ BASE_URL = "https://news.qq.com/ch/finance_stock/" # 腾讯新闻API(如果页面动态加载,可能需要调用API) API_URL = "https://pacaio.match.qq.com/irs/rcd" SOURCE_NAME = "tencent" def __init__(self): super().__init__( name="tencent_finance_crawler", description="Crawl financial news from Tencent Finance (news.qq.com)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取腾讯财经新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: # 腾讯财经页面只爬取首页 page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled Tencent Finance, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling Tencent Finance: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: """ 爬取单页新闻 优先使用API获取新闻,如果API失败则回退到HTML解析 Args: page: 页码 Returns: 新闻列表 """ news_items = [] # 先尝试使用API获取新闻 try: logger.info(f"[Tencent] Attempting API fetch for page {page}") api_news = self._fetch_api_news(page) logger.info(f"[Tencent] API returned {len(api_news) if api_news else 0} news items") if api_news: logger.info(f"Fetched {len(api_news)} news from API") for news_data in api_news[:20]: # 限制20条 try: news_item = self._parse_api_news_item(news_data) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to parse API news item: {e}") continue if news_items: logger.info(f"[Tencent] Successfully parsed {len(news_items)} news items from API") return news_items else: logger.info(f"[Tencent] API returned empty list, falling back to HTML") except Exception as e: logger.warning(f"API fetch failed, fallback to HTML: {e}") # API失败,回退到HTML解析 try: response = self._fetch_page(self.BASE_URL) # 腾讯新闻可能使用动态加载,确保编码正确 if response.encoding == 'ISO-8859-1' or not response.encoding: response.encoding = 'utf-8' soup = self._parse_html(response.text) # 提取新闻列表 # 腾讯的新闻可能在各种容器中,尝试提取所有新闻链接 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links from HTML") # 限制爬取数量,避免过多请求 max_news = 20 for i, link_info in enumerate(news_links[:max_news]): try: news_item = self._extract_news_item(link_info) if news_item: 
news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item {i+1}: {e}") continue except Exception as e: logger.error(f"Error crawling page {page}: {e}") return news_items def _fetch_api_news(self, page: int = 0) -> List[dict]: """ 通过API获取新闻列表 Args: page: 页码(从0开始) Returns: 新闻列表 """ try: # 腾讯新闻API参数(根据实际API文档调整) params = { "cid": "finance_stock", # 股票频道 "page": page, "num": 20, # 每页20条 "ext": "finance_stock", # 扩展参数 } headers = { "User-Agent": self.user_agent, "Referer": self.BASE_URL, "Accept": "application/json, text/javascript, */*; q=0.01", } logger.info(f"[Tencent] Calling API: {self.API_URL} with params: {params}") response = self.session.get( self.API_URL, params=params, headers=headers, timeout=self.timeout ) logger.info(f"[Tencent] API response status: {response.status_code}") response.raise_for_status() # 解析JSON响应(可能是JSONP格式) content = response.text.strip() logger.info(f"[Tencent] API response preview (first 500 chars): {content[:500]}") # 尝试解析JSONP格式 if content.startswith('callback(') or content.startswith('jQuery'): # 提取JSON部分 import re json_match = re.search(r'\((.*)\)$', content) if json_match: content = json_match.group(1) data = json.loads(content) logger.info(f"[Tencent] Parsed API response type: {type(data)}, keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}") if isinstance(data, dict): if 'data' in data: logger.info(f"[Tencent] Found 'data' key with {len(data['data']) if isinstance(data['data'], list) else 'non-list'} items") return data['data'] elif 'list' in data: logger.info(f"[Tencent] Found 'list' key with {len(data['list']) if isinstance(data['list'], list) else 'non-list'} items") return data['list'] elif 'result' in data: logger.info(f"[Tencent] Found 'result' key with {len(data['result']) if isinstance(data['result'], list) else 'non-list'} items") return data['result'] else: logger.warning(f"[Tencent] Unexpected API response format, keys: {list(data.keys())}") elif isinstance(data, list): logger.info(f"[Tencent] API returned list with {len(data)} items") return data logger.warning(f"Unexpected API response format: {type(data)}") return [] except json.JSONDecodeError as e: logger.warning(f"API JSON decode failed: {e}, response preview: {response.text[:200] if 'response' in locals() else 'N/A'}") return [] except Exception as e: logger.warning(f"API fetch failed: {e}") return [] def _parse_api_news_item(self, news_data: dict) -> Optional[NewsItem]: """ 解析API返回的新闻数据 Args: news_data: API返回的单条新闻数据 Returns: NewsItem对象 """ try: # 提取基本信息 title = news_data.get('title', '').strip() url = news_data.get('url', '') or news_data.get('surl', '') # 确保URL是完整的 if url and not url.startswith('http'): if url.startswith('//'): url = 'https:' + url elif url.startswith('/'): url = 'https://news.qq.com' + url else: url = 'https://news.qq.com/' + url.lstrip('/') if not title or not url: return None # 提取发布时间 publish_time_str = news_data.get('time', '') or news_data.get('publish_time', '') publish_time = self._parse_time_string(publish_time_str) if publish_time_str else datetime.now() # 提取摘要作为内容(API通常不返回完整内容) content = news_data.get('abstract', '') or news_data.get('intro', '') or title # 提取作者 author = news_data.get('author', '') or news_data.get('source', '') # 尝试获取完整内容 try: response = self._fetch_page(url) if response.encoding == 'ISO-8859-1' or not response.encoding: response.encoding = 'utf-8' raw_html = response.text soup = self._parse_html(raw_html) full_content = self._extract_content(soup) if full_content and len(full_content) > 
len(content): content = full_content except Exception as e: logger.debug(f"Failed to fetch full content from {url}: {e}") raw_html = None return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author if author else None, raw_html=raw_html, ) except Exception as e: logger.warning(f"Failed to parse API news item: {e}") return None def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """ 从页面中提取新闻链接 Args: soup: BeautifulSoup对象 Returns: 新闻链接信息列表 """ news_links = [] # 查找所有链接 all_links = soup.find_all('a', href=True) # 腾讯新闻URL模式(扩展更多模式) tencent_patterns = [ '/rain/a/', # 旧模式 '/omn/', # 旧模式 '/a/', # 新模式 '/finance/', # 财经频道 'finance.qq.com', # 财经域名 '/stock/', # 股票相关 ] for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 检查是否匹配腾讯新闻URL模式 is_tencent_url = False for pattern in tencent_patterns: if pattern in href: is_tencent_url = True break # 或者检查是否是qq.com域名且包含新闻相关关键词 if not is_tencent_url: if 'qq.com' in href and any(kw in href for kw in ['/a/', '/article/', '/news/', '/finance/']): is_tencent_url = True if is_tencent_url and title and len(title.strip()) > 5: # 确保是完整URL if not href.startswith('http'): if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://news.qq.com' + href else: href = 'https://news.qq.com/' + href.lstrip('/') # 过滤掉明显不是新闻的链接 if any(skip in href.lower() for skip in ['javascript:', 'mailto:', '#', 'void(0)']): continue if href not in [n['url'] for n in news_links]: news_links.append({ 'url': href, 'title': title.strip() }) logger.debug(f"Tencent: Found {len(news_links)} potential news links") return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """ 提取单条新闻详情 Args: link_info: 新闻链接信息 Returns: NewsItem或None """ url = link_info['url'] title = link_info['title'] try: # 获取新闻详情页 response = self._fetch_page(url) # 确保编码正确 if response.encoding == 'ISO-8859-1' or not response.encoding: response.encoding = 'utf-8' raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文内容 content = self._extract_content(soup) if not content: logger.debug(f"No content found for: {title}") return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """ 提取新闻正文 Args: soup: BeautifulSoup对象 Returns: 新闻正文 """ # 尝试多种选择器 content_selectors = [ {'class': 'content-article'}, {'class': 'LEFT'}, {'id': 'Cnt-Main-Article-QQ'}, {'class': 'article'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: # 获取所有段落 paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content: return self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """ 提取发布时间 Args: soup: BeautifulSoup对象 Returns: 发布时间 """ try: # 尝试多种时间选择器 time_selectors = [ {'class': 'a-time'}, {'class': 'article-time'}, {'class': 'time'}, ] for selector in time_selectors: time_elem = soup.find('span', selector) if time_elem: time_str = time_elem.get_text(strip=True) return 
self._parse_time_string(time_str) # 尝试从meta标签获取 meta_time = soup.find('meta', {'property': 'article:published_time'}) if meta_time and meta_time.get('content'): return datetime.fromisoformat(meta_time['content'].replace('Z', '+00:00')) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") # 默认返回当前时间 return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """ 解析时间字符串(如"1小时前"、"昨天"、"2024-12-06 10:00") Args: time_str: 时间字符串 Returns: datetime对象 """ now = datetime.now() # 处理相对时间 if '分钟前' in time_str: minutes = int(re.search(r'(\d+)', time_str).group(1)) return now - timedelta(minutes=minutes) elif '小时前' in time_str: hours = int(re.search(r'(\d+)', time_str).group(1)) return now - timedelta(hours=hours) elif '昨天' in time_str: return now - timedelta(days=1) elif '前天' in time_str: return now - timedelta(days=2) # 尝试解析绝对时间 try: # 尝试多种格式 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue except Exception: pass # 默认返回当前时间 return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """ 提取作者 Args: soup: BeautifulSoup对象 Returns: 作者名称 """ try: # 尝试多种作者选择器 author_selectors = [ {'class': 'author'}, {'class': 'article-author'}, {'class': 'source'}, ] for selector in author_selectors: author_elem = soup.find('span', selector) or soup.find('a', selector) if author_elem: author = author_elem.get_text(strip=True) if author: return author except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/app/tools/text_cleaner.py ================================================ """ 文本清洗工具 重构自 legacy_v1/src/Killua/ """ import re import logging from typing import List, Set import jieba from agenticx import BaseTool from agenticx.core import ToolMetadata, ToolCategory logger = logging.getLogger(__name__) class TextCleanerTool(BaseTool): """ 文本清洗工具 提供去停用词、分词、文本标准化等功能 """ # 中文停用词列表(简化版) STOP_WORDS = { '的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这', '那', '里', '就是', '什么', '可以', '为', '以', '及', '等', '将', '并', '个', '与', '对', '如', '所', '于', '被', '由', '从', '而', '把', '让', '向', '却', '但', '或', '及', '但是', '然而', '因为', '所以', '如果', '虽然', '尽管', '无论', '不管', '只要', '除非', '、', ',', '。', ';', ':', '?', '!', '"', '"', ''', ''', '(', ')', '【', '】', '《', '》', '—', '…', '·', '~', '#', '@', '&', } def __init__(self): metadata = ToolMetadata( name="text_cleaner", description="Clean and preprocess Chinese financial text", category=ToolCategory.UTILITY, version="1.0.0" ) super().__init__(metadata=metadata) # 初始化jieba jieba.setLogLevel(logging.WARNING) # 加载金融领域自定义词典(可选) self._load_custom_dict() def _load_custom_dict(self): """加载自定义词典""" # 金融领域常用词 financial_words = [ '股票', '证券', '基金', '债券', '期货', '期权', '外汇', '上证指数', '深证成指', '创业板', '科创板', '涨停', '跌停', '停牌', '复牌', '退市', '上市', '市盈率', '市净率', '市值', '流通股', '限售股', '分红', '配股', '增发', '回购', '重组', '并购', '利好', '利空', '看多', '看空', '做多', '做空', '成交量', '换手率', '振幅', '量比', ] for word in financial_words: jieba.add_word(word) def clean_text(self, text: str) -> str: """ 基础文本清洗 Args: text: 原始文本 Returns: 清洗后的文本 """ if not text: return "" # 移除URL text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text) # 移除邮箱 text = re.sub(r'[\w\.-]+@[\w\.-]+\.\w+', '', text) # 移除特殊字符(保留中文、英文、数字) text = 
re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s\.\,\!\?\:\;\-\%\(\)]', '', text) # 统一空格 text = re.sub(r'\s+', ' ', text) return text.strip() def tokenize(self, text: str, remove_stopwords: bool = True) -> List[str]: """ 中文分词 Args: text: 文本 remove_stopwords: 是否去除停用词 Returns: 词语列表 """ # 分词 words = jieba.cut(text) # 过滤 result = [] for word in words: word = word.strip() if not word: continue # 去除停用词 if remove_stopwords and word in self.STOP_WORDS: continue # 去除单字符(除了一些特殊字如"涨"、"跌") if len(word) == 1 and not re.match(r'[\u4e00-\u9fa5]', word): continue result.append(word) return result def extract_keywords(self, text: str, top_k: int = 10) -> List[str]: """ 提取关键词 Args: text: 文本 top_k: 返回的关键词数量 Returns: 关键词列表 """ import jieba.analyse keywords = jieba.analyse.extract_tags( text, topK=top_k, withWeight=False ) return keywords def normalize_stock_code(self, code: str) -> str: """ 标准化股票代码 Args: code: 原始代码(如 sh600519, 600519, SH600519) Returns: 标准化代码(如 600519) """ code = code.upper().strip() # 移除市场前缀 code = re.sub(r'^(SH|SZ|HK)', '', code) return code def _setup_parameters(self): """设置工具参数(AgenticX 要求)""" # TextCleanerTool 的参数通过 execute 方法的 kwargs 传递 pass def execute(self, **kwargs) -> dict: """ 同步执行方法(AgenticX Tool 协议要求) Args: **kwargs: 参数字典 - text: 输入文本(必需) - operation: 操作类型(clean, tokenize, keywords),默认 "clean" - remove_stopwords: 是否去除停用词(仅用于 tokenize),默认 True - top_k: 关键词数量(仅用于 keywords),默认 10 Returns: 执行结果 """ text = kwargs.get("text", "") if not text: return {"success": False, "error": "Missing required parameter: text"} operation = kwargs.get("operation", "clean") if operation == "clean": result = self.clean_text(text) return {"success": True, "result": result} elif operation == "tokenize": remove_stopwords = kwargs.get("remove_stopwords", True) result = self.tokenize(text, remove_stopwords) return {"success": True, "result": result, "count": len(result)} elif operation == "keywords": top_k = kwargs.get("top_k", 10) result = self.extract_keywords(text, top_k) return {"success": True, "result": result} else: return {"success": False, "error": f"Unknown operation: {operation}"} async def aexecute(self, **kwargs) -> dict: """ 异步执行方法(AgenticX Tool 协议要求) 当前实现为同步执行的包装 Args: **kwargs: 参数字典 Returns: 执行结果 """ return self.execute(**kwargs) # 便捷创建函数 def create_text_cleaner() -> TextCleanerTool: """创建文本清洗工具实例""" return TextCleanerTool() ================================================ FILE: backend/app/tools/yicai_crawler.py ================================================ """ 第一财经爬虫工具 目标URL: https://www.yicai.com/news/gushi/ """ import re import logging from typing import List, Optional from datetime import datetime from bs4 import BeautifulSoup from .crawler_base import BaseCrawler, NewsItem logger = logging.getLogger(__name__) class YicaiCrawlerTool(BaseCrawler): """ 第一财经爬虫 主要爬取股市新闻 """ BASE_URL = "https://www.yicai.com/" STOCK_URL = "https://www.yicai.com/news/gushi/" SOURCE_NAME = "yicai" def __init__(self): super().__init__( name="yicai_crawler", description="Crawl financial news from Yicai (yicai.com)" ) def crawl(self, start_page: int = 1, end_page: int = 1) -> List[NewsItem]: """ 爬取第一财经新闻 Args: start_page: 起始页码 end_page: 结束页码 Returns: 新闻列表 """ news_list = [] try: page_news = self._crawl_page(1) news_list.extend(page_news) logger.info(f"Crawled Yicai, got {len(page_news)} news items") except Exception as e: logger.error(f"Error crawling Yicai: {e}") # 应用股票筛选 filtered_news = self._filter_stock_news(news_list) return filtered_news def _crawl_page(self, page: int) -> List[NewsItem]: """爬取单页新闻""" news_items = [] 
try: response = self._fetch_page(self.STOCK_URL) soup = self._parse_html(response.text) # 提取新闻列表 news_links = self._extract_news_links(soup) logger.info(f"Found {len(news_links)} potential news links") # 限制爬取数量 max_news = 20 for link_info in news_links[:max_news]: try: news_item = self._extract_news_item(link_info) if news_item: news_items.append(news_item) except Exception as e: logger.warning(f"Failed to extract news item: {e}") continue except Exception as e: logger.error(f"Error crawling page: {e}") return news_items def _extract_news_links(self, soup: BeautifulSoup) -> List[dict]: """从页面中提取新闻链接""" news_links = [] # 查找新闻链接 all_links = soup.find_all('a', href=True) for link in all_links: href = link.get('href', '') title = link.get_text(strip=True) # 第一财经新闻URL模式 if ('/news/' in href or '/article/' in href) and title: # 确保是完整URL if href.startswith('//'): href = 'https:' + href elif href.startswith('/'): href = 'https://www.yicai.com' + href elif not href.startswith('http'): href = 'https://www.yicai.com/' + href.lstrip('/') if href not in [n['url'] for n in news_links]: news_links.append({'url': href, 'title': title}) return news_links def _extract_news_item(self, link_info: dict) -> Optional[NewsItem]: """提取单条新闻详情""" url = link_info['url'] title = link_info['title'] try: response = self._fetch_page(url) raw_html = response.text # 保存原始 HTML soup = self._parse_html(raw_html) # 提取正文 content = self._extract_content(soup) if not content: return None # 提取发布时间 publish_time = self._extract_publish_time(soup) # 提取作者 author = self._extract_author(soup) return NewsItem( title=title, content=content, url=url, source=self.SOURCE_NAME, publish_time=publish_time, author=author, raw_html=raw_html, # 保存原始 HTML ) except Exception as e: logger.warning(f"Failed to extract news from {url}: {e}") return None def _extract_content(self, soup: BeautifulSoup) -> str: """提取新闻正文""" content_selectors = [ {'class': 'm-txt'}, {'class': 'article-content'}, {'class': 'content'}, {'class': 'newsContent'}, ] for selector in content_selectors: content_div = soup.find('div', selector) if content_div: paragraphs = content_div.find_all('p') if paragraphs: content = '\n'.join([p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True)]) if content: return self._clean_text(content) # 后备方案:使用基类的智能提取方法 return self._extract_article_content(soup) return "" def _extract_publish_time(self, soup: BeautifulSoup) -> Optional[datetime]: """提取发布时间""" try: time_elem = soup.find('span', {'class': re.compile(r'time|date')}) if not time_elem: time_elem = soup.find('time') if time_elem: time_str = time_elem.get_text(strip=True) return self._parse_time_string(time_str) except Exception as e: logger.debug(f"Failed to parse publish time: {e}") return datetime.now() def _parse_time_string(self, time_str: str) -> datetime: """解析时间字符串""" now = datetime.now() # 尝试解析绝对时间 formats = [ '%Y-%m-%d %H:%M:%S', '%Y-%m-%d %H:%M', '%Y-%m-%d', '%Y年%m月%d日 %H:%M', '%Y年%m月%d日', ] for fmt in formats: try: return datetime.strptime(time_str, fmt) except ValueError: continue return now def _extract_author(self, soup: BeautifulSoup) -> Optional[str]: """提取作者""" try: author_elem = soup.find('span', {'class': re.compile(r'author|source')}) if author_elem: return author_elem.get_text(strip=True) except Exception as e: logger.debug(f"Failed to extract author: {e}") return None ================================================ FILE: backend/clear_news_data.py ================================================ """ 清除所有新闻相关数据 """ import os import sys from pathlib 
import Path # 加载环境变量 from dotenv import load_dotenv env_path = Path(__file__).parent / ".env" load_dotenv(env_path) # 构建数据库 URL POSTGRES_USER = os.getenv("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "postgres") POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost") POSTGRES_PORT = os.getenv("POSTGRES_PORT", "5432") POSTGRES_DB = os.getenv("POSTGRES_DB", "finnews_db") DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" from sqlalchemy import create_engine, text def clear_all_news_data(): """清除所有新闻相关数据""" print("🗑️ 正在清除所有新闻数据...") engine = create_engine(DATABASE_URL) with engine.connect() as conn: # 查询存在的表 result = conn.execute(text(""" SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' """)) existing_tables = [row[0] for row in result.fetchall()] print(f" 数据库中的表: {existing_tables}") # 清除 news 表 if 'news' in existing_tables: result = conn.execute(text("SELECT COUNT(*) FROM news")) news_count = result.scalar() print(f" 当前新闻数量: {news_count}") conn.execute(text("TRUNCATE TABLE news RESTART IDENTITY CASCADE")) print(" ✅ news 表已清除") else: print(" ⚠️ news 表不存在") # 清除 news_analysis 表(如果存在) if 'news_analysis' in existing_tables: result = conn.execute(text("SELECT COUNT(*) FROM news_analysis")) analysis_count = result.scalar() print(f" 当前分析数量: {analysis_count}") conn.execute(text("TRUNCATE TABLE news_analysis RESTART IDENTITY CASCADE")) print(" ✅ news_analysis 表已清除") # 清除 analysis 表(如果存在) if 'analysis' in existing_tables: result = conn.execute(text("SELECT COUNT(*) FROM analysis")) analysis_count = result.scalar() print(f" 当前 analysis 数量: {analysis_count}") conn.execute(text("TRUNCATE TABLE analysis RESTART IDENTITY CASCADE")) print(" ✅ analysis 表已清除") conn.commit() print("\n✅ 所有新闻数据已清除!") if __name__ == "__main__": print("=" * 50) print("📰 FinnewsHunter - 清除新闻数据") print("=" * 50) # 确认操作 if len(sys.argv) > 1 and sys.argv[1] == "--yes": confirm = "y" else: confirm = input("\n⚠️ 确定要清除所有新闻数据吗?(y/N): ").strip().lower() if confirm == "y": clear_all_news_data() print("\n🎉 完成!") else: print("❌ 已取消操作") ================================================ FILE: backend/env.example ================================================ # FinnewsHunter 环境变量配置模板 # 复制此文件为 .env 并填入实际值 # ===== 应用配置 ===== APP_NAME=FinnewsHunter APP_VERSION=0.1.0 DEBUG=True # ===== 服务器配置 ===== HOST=0.0.0.0 PORT=8000 # ===== 数据库配置 ===== POSTGRES_USER=finnews POSTGRES_PASSWORD=finnews_dev_password POSTGRES_HOST=localhost POSTGRES_PORT=5432 POSTGRES_DB=finnews_db # ===== Redis 配置 ===== REDIS_HOST=localhost REDIS_PORT=6379 REDIS_DB=0 # REDIS_PASSWORD= # 可选,生产环境建议设置 # ===== Milvus 配置 ===== MILVUS_HOST=localhost MILVUS_PORT=19530 MILVUS_COLLECTION_NAME=finnews_embeddings # ⚠️ 重要:向量维度必须与 Embedding 模型匹配 # - OpenAI text-embedding-ada-002: 1536 维 # - 百炼 text-embedding-v4: 1024 维 MILVUS_DIM=1536 # ===== Neo4j 知识图谱配置 ===== NEO4J_URI=bolt://localhost:7687 NEO4J_USER=neo4j NEO4J_PASSWORD=finnews_neo4j_password # ========================================== # LLM 和 Embedding 配置 # ========================================== # 支持5个厂商:bailian、openai、deepseek、kimi、zhipu # 前端可以动态切换,后端需要配置对应的 API Key # ===== 默认LLM配置(可选,用于后端默认行为) ===== LLM_PROVIDER=bailian # 默认提供商 LLM_MODEL=qwen-plus # 默认模型 LLM_TEMPERATURE=0.7 LLM_MAX_TOKENS=2000 LLM_TIMEOUT=180 # LLM 调用超时时间(秒) # ========================================== # 各厂商 API Key 配置 # ========================================== # ⚠️ 注意:前端可以切换任意厂商,请配置所有需要使用的厂商的 API Key # ----- 
1. 百炼(Bailian / 阿里云)----- # 获取地址:https://dashscope.console.aliyun.com/ DASHSCOPE_API_KEY=your-dashscope-api-key-here DASHSCOPE_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 # 可用模型列表(逗号分隔,可自定义添加新模型) BAILIAN_MODELS=qwen-plus,qwen-max,qwen-turbo,qwen-long # 百炼可选配置(如需使用 Agent 功能) # BAILIAN_ACCESS_KEY_ID=your-access-key-id # BAILIAN_ACCESS_KEY_SECRET=your-access-key-secret # BAILIAN_AGENT_CODE=your-agent-code # BAILIAN_REGION_ID=cn-beijing # ----- 2. OpenAI ----- # 获取地址:https://platform.openai.com/api-keys OPENAI_API_KEY=your-openai-api-key-here OPENAI_BASE_URL= # 留空使用官方 API,或填写代理地址 # 可用模型列表(逗号分隔,可自定义添加新模型) OPENAI_MODELS=gpt-4,gpt-4-turbo,gpt-3.5-turbo # ----- 3. DeepSeek ----- # 获取地址:https://platform.deepseek.com/api_keys DEEPSEEK_API_KEY=your-deepseek-api-key-here DEEPSEEK_BASE_URL=https://api.deepseek.com/v1 # 默认值,可不填 # 可用模型列表(逗号分隔,可自定义添加新模型) DEEPSEEK_MODELS=deepseek-chat,deepseek-coder # ----- 4. Kimi (Moonshot) ----- # 获取地址:https://platform.moonshot.cn/console/api-keys MOONSHOT_API_KEY=your-moonshot-api-key-here MOONSHOT_BASE_URL=https://api.moonshot.cn/v1 # 默认值,可不填 # 可用模型列表(逗号分隔,可自定义添加新模型) MOONSHOT_MODELS=moonshot-v1-8k,moonshot-v1-32k,moonshot-v1-128k # ----- 5. 智谱 (Zhipu AI) ----- # 获取地址:https://open.bigmodel.cn/usercenter/apikeys ZHIPU_API_KEY=your-zhipu-api-key-here ZHIPU_BASE_URL=https://open.bigmodel.cn/api/paas/v4 # 默认值,可不填 # 可用模型列表(逗号分隔,可自定义添加新模型) ZHIPU_MODELS=glm-4,glm-4-plus,glm-4-air,glm-3-turbo # ----- 6. BochaAI (Web Search API) ----- # 获取地址:https://bochaai.com/ # 用于定向爬取股票新闻时的搜索引擎 BOCHAAI_API_KEY=your-bochaai-api-key-here BOCHAAI_ENDPOINT=https://api.bochaai.com/v1/web-search # 默认值,可不填 # ========================================== # Embedding 配置 # ========================================== # EMBEDDING_PROVIDER=openai # EMBEDDING_MODEL=text-embedding-ada-002 # EMBEDDING_BATCH_SIZE=100 # EMBEDDING_BASE_URL= # 留空使用官方 API # 使用百炼 Embedding 时的配置示例: EMBEDDING_PROVIDER=openai EMBEDDING_MODEL=text-embedding-v4 EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 MILVUS_DIM=1024 # 百炼 embedding 是 1024 维 # ===== 爬取间隔配置(多源支持)===== CRAWL_INTERVAL_SINA=60 # 新浪财经爬取间隔(秒) CRAWL_INTERVAL_TENCENT=60 # 腾讯财经爬取间隔(秒) CRAWL_INTERVAL_JWVIEW=60 # 中新经纬爬取间隔(秒) CRAWL_INTERVAL_EEO=60 # 经济观察网爬取间隔(秒) CRAWL_INTERVAL_CAIJING=60 # 财经网爬取间隔(秒) CRAWL_INTERVAL_JINGJI21=60 # 21经济网爬取间隔(秒) # ===== 实时爬取与缓存配置 ===== CACHE_TTL=1800 # 缓存过期时间(秒),默认30分钟 NEWS_RETENTION_HOURS=24 # 新闻保留时间(小时),默认24小时 FRONTEND_REFETCH_INTERVAL=180 # 前端自动刷新间隔(秒),默认3分钟 # ===== 爬虫配置 ===== CRAWLER_TIMEOUT=30 CRAWLER_MAX_RETRIES=3 CRAWLER_DELAY=1.0 # ===== 安全配置 ===== SECRET_KEY=your-secret-key-here-please-change-in-production ACCESS_TOKEN_EXPIRE_MINUTES=10080 # ===== 日志配置 ===== LOG_LEVEL=INFO LOG_FILE=logs/finnews.log # ===== 业务配置 ===== MAX_NEWS_PER_REQUEST=50 NEWS_CACHE_TTL=3600 ================================================ FILE: backend/init_db.py ================================================ #!/usr/bin/env python """ 数据库初始化脚本 独立运行以创建数据库表 """ import sys import os # 添加当前目录到 Python 路径 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) if __name__ == "__main__": print("=" * 60) print("Initializing FinnewsHunter Database...") print("=" * 60) try: from sqlalchemy import create_engine from sqlalchemy.orm import declarative_base from app.core.config import settings # 导入所有模型 from app.models.database import Base from app.models.news import News from app.models.stock import Stock from app.models.analysis import Analysis from app.models.crawl_task import CrawlTask from app.models.debate_history import DebateHistory 
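# Note: every ORM model above is imported explicitly so that it registers itself
# on Base.metadata; only then can Base.metadata.create_all() below create all of
# the tables. Hedged verification sketch (illustrative only, to be run after
# create_all has completed):
#     from sqlalchemy import inspect
#     print(inspect(sync_engine).get_table_names())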
print(f"\nConnecting to database: {settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/{settings.POSTGRES_DB}") # 创建同步引擎 sync_engine = create_engine( settings.SYNC_DATABASE_URL, echo=False, pool_pre_ping=True, ) print("Creating tables...") Base.metadata.create_all(bind=sync_engine) print("\nDatabase initialized successfully!") print(f" - News table created") print(f" - Stock table created") print(f" - Analysis table created") print(f" - CrawlTask table created") print(f" - DebateHistory table created") print("=" * 60) sys.exit(0) except Exception as e: print(f"\nDatabase initialization failed: {e}") import traceback traceback.print_exc() print("=" * 60) print("\nNote: If tables already exist, this error is expected.") print("You can safely ignore it and proceed with starting the server.") sys.exit(0) # 即使失败也返回0,因为表可能已存在 ================================================ FILE: backend/init_knowledge_graph.py ================================================ #!/usr/bin/env python """ 初始化知识图谱 创建 Neo4j 约束、索引,并为示例股票构建图谱 """ import asyncio import logging import sys # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) async def init_knowledge_graph(): """初始化知识图谱""" try: from app.core.neo4j_client import get_neo4j_client from app.knowledge.graph_service import get_graph_service from app.knowledge.knowledge_extractor import ( create_knowledge_extractor, AkshareKnowledgeExtractor ) logger.info("=" * 80) logger.info("开始初始化知识图谱") logger.info("=" * 80) # 1. 测试 Neo4j 连接 logger.info("\n[1/4] 测试 Neo4j 连接...") neo4j_client = get_neo4j_client() if neo4j_client.health_check(): logger.info("Neo4j 连接正常") else: logger.error("Neo4j 连接失败,请检查配置") sys.exit(1) # 2. 初始化约束和索引 logger.info("\n[2/4] 初始化数据库约束和索引...") graph_service = get_graph_service() logger.info("约束和索引已创建") # 3. 为示例股票创建图谱 logger.info("\n[3/4] 为示例股票创建知识图谱...") example_stocks = [ ("SH600519", "贵州茅台"), # 示例1:大盘蓝筹 ("SZ300634", "彩讯股份"), # 示例2:中小板 ] extractor = create_knowledge_extractor() for stock_code, stock_name in example_stocks: logger.info(f"\n处理: {stock_name}({stock_code})") # 检查是否已存在 existing = graph_service.get_company_graph(stock_code) if existing: logger.info(f" 图谱已存在,跳过") continue # 从 akshare 获取信息 logger.info(f" 从 akshare 获取信息...") akshare_info = AkshareKnowledgeExtractor.extract_company_info(stock_code) if not akshare_info: logger.warning(f" akshare 未返回数据,跳过") continue # 使用 LLM 提取详细信息 logger.info(f" 使用 LLM 提取详细信息...") base_graph = await extractor.extract_from_akshare( stock_code, stock_name, akshare_info ) # 构建图谱 logger.info(f" 构建图谱...") success = graph_service.build_company_graph(base_graph) if success: stats = graph_service.get_graph_stats(stock_code) logger.info(f" 图谱构建成功: {stats}") else: logger.error(f" 图谱构建失败") # 4. 显示统计信息 logger.info("\n[4/4] 图谱统计...") companies = graph_service.list_all_companies() logger.info(f"当前共有 {len(companies)} 家公司的知识图谱") for company in companies: stats = graph_service.get_graph_stats(company['stock_code']) logger.info(f" - {company['stock_name']}({company['stock_code']}): {stats}") logger.info("\n" + "=" * 80) logger.info("知识图谱初始化完成!") logger.info("=" * 80) logger.info("\n下一步:") logger.info(" 1. 访问 http://localhost:7474 查看 Neo4j 浏览器") logger.info(" 2. 用户名: neo4j, 密码: finnews_neo4j_password") logger.info(" 3. 
执行定向爬取时,系统会自动使用知识图谱进行多关键词并发检索") logger.info("\n") except Exception as e: logger.error(f"初始化失败: {e}", exc_info=True) sys.exit(1) if __name__ == "__main__": asyncio.run(init_knowledge_graph()) ================================================ FILE: backend/requirements.txt ================================================ # ===== Web 框架 ===== fastapi>=0.100.0 uvicorn[standard]>=0.22.0 pydantic>=2.0.0 pydantic-settings>=2.0.0 python-dotenv>=1.0.0 # ===== 数据库 ===== sqlalchemy>=2.0.0 asyncpg>=0.29.0 # PostgreSQL 异步驱动 psycopg2-binary>=2.9.0 # PostgreSQL 同步驱动(用于初始化) alembic>=1.12.0 # 数据库迁移工具 # ===== 缓存与任务队列 ===== redis>=4.5.0 celery>=5.3.0 # ===== 向量数据库 ===== pymilvus>=2.3.0 # ===== 图数据库 ===== neo4j>=5.14.0 # Neo4j Python驱动 # ===== 网络请求与爬虫 ===== requests>=2.31.0 beautifulsoup4>=4.12.0 lxml>=4.9.0 aiohttp>=3.9.0 markdownify>=0.11.0 # HTML 转 Markdown readabilipy>=0.2.0 # 智能内容提取(Mozilla Readability) playwright>=1.40.0 # JS 渲染(可选,需运行 playwright install) # ===== AI/ML ===== openai>=1.0.0 anthropic>=0.7.0 litellm>=1.0.0 tiktoken>=0.5.0 # Token 计数 # ===== 文本处理 ===== jieba>=0.42.1 # 中文分词 python-dateutil>=2.8.2 # ===== 工具库 ===== httpx>=0.25.0 tenacity>=8.2.0 # 重试机制 # ===== AgenticX 框架 ===== agenticx==0.1.9 # Docker 容器中使用 PyPI 版本 # 本地开发可以用:pip install -e ../../../../agenticx akshare ================================================ FILE: backend/reset_database.py ================================================ """ 清空数据库并重新开始 用于重置系统数据 """ import asyncio import sys from sqlalchemy import text from app.core.database import get_async_engine from app.core.redis_client import redis_client import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) async def reset_database(): """清空所有数据""" engine = get_async_engine() try: async with engine.begin() as conn: logger.info("=" * 60) logger.info("开始清空数据库...") logger.info("=" * 60) # 1. 清空新闻表 logger.info("清空新闻表 (news)...") result = await conn.execute(text("DELETE FROM news")) logger.info(f"✅ 已删除 {result.rowcount} 条新闻记录") # 2. 清空爬取任务表 logger.info("清空爬取任务表 (crawl_tasks)...") result = await conn.execute(text("DELETE FROM crawl_tasks")) logger.info(f"✅ 已删除 {result.rowcount} 条任务记录") # 3. 清空分析表(如果存在) try: logger.info("清空分析表 (analyses)...") result = await conn.execute(text("DELETE FROM analyses")) logger.info(f"✅ 已删除 {result.rowcount} 条分析记录") except Exception as e: logger.warning(f"清空分析表失败(表可能不存在): {e}") # 4. 重置自增ID logger.info("重置表自增ID...") try: await conn.execute(text("ALTER SEQUENCE news_id_seq RESTART WITH 1")) await conn.execute(text("ALTER SEQUENCE crawl_tasks_id_seq RESTART WITH 1")) await conn.execute(text("ALTER SEQUENCE analyses_id_seq RESTART WITH 1")) logger.info("✅ 自增ID已重置") except Exception as e: logger.warning(f"重置自增ID失败: {e}") logger.info("=" * 60) logger.info("数据库清空完成!") logger.info("=" * 60) # 5. 清空Redis缓存 if redis_client.is_available(): logger.info("清空Redis缓存...") try: # 删除所有news相关的缓存键 redis_client.client.flushdb() logger.info("✅ Redis缓存已清空") except Exception as e: logger.error(f"清空Redis失败: {e}") else: logger.warning("⚠️ Redis不可用,跳过缓存清理") logger.info("=" * 60) logger.info("✨ 数据重置完成!") logger.info("=" * 60) logger.info("下一步:") logger.info("1. 重启 Celery Worker 和 Beat") logger.info("2. 系统将自动开始爬取最新新闻") logger.info("3. 
约5-10分钟后可在前端查看新数据") logger.info("=" * 60) except Exception as e: logger.error(f"❌ 清空数据失败: {e}") import traceback traceback.print_exc() sys.exit(1) finally: await engine.dispose() if __name__ == "__main__": # 确认操作 print("⚠️ 警告:此操作将删除所有新闻和任务数据!") print("⚠️ 此操作不可恢复!") confirm = input("确认要清空所有数据吗?(yes/no): ") if confirm.lower() in ['yes', 'y']: asyncio.run(reset_database()) else: print("❌ 操作已取消") sys.exit(0) ================================================ FILE: backend/setup_env.sh ================================================ #!/bin/bash # 环境变量快速配置脚本 echo "============================================" echo " FinnewsHunter 环境配置向导" echo "============================================" echo "" echo "请选择 LLM 服务商:" echo " 1) OpenAI 官方(默认)" echo " 2) 阿里云百炼(推荐国内用户)" echo " 3) 其他 OpenAI 代理" echo " 4) 手动配置(复制模板)" echo "" read -p "请输入选项 (1-4) [默认:1]: " choice choice=${choice:-1} case $choice in 1) # OpenAI 官方 cat > .env << 'EOF' # FinnewsHunter 环境配置 APP_NAME=FinnewsHunter DEBUG=True POSTGRES_USER=finnews POSTGRES_PASSWORD=finnews_dev_password POSTGRES_HOST=localhost POSTGRES_PORT=5432 POSTGRES_DB=finnews_db REDIS_HOST=localhost REDIS_PORT=6379 MILVUS_HOST=localhost MILVUS_PORT=19530 MILVUS_DIM=1536 # OpenAI 官方配置 LLM_PROVIDER=openai LLM_MODEL=gpt-3.5-turbo LLM_TEMPERATURE=0.7 LLM_MAX_TOKENS=2000 OPENAI_API_KEY=sk-your-openai-api-key-here EMBEDDING_PROVIDER=openai EMBEDDING_MODEL=text-embedding-ada-002 LOG_LEVEL=INFO EOF echo "" echo "OpenAI 配置已创建" echo "请编辑 .env 并填入你的 OPENAI_API_KEY" ;; 2) # 阿里云百炼 cat > .env << 'EOF' # FinnewsHunter 环境配置 APP_NAME=FinnewsHunter DEBUG=True POSTGRES_USER=finnews POSTGRES_PASSWORD=finnews_dev_password POSTGRES_HOST=localhost POSTGRES_PORT=5432 POSTGRES_DB=finnews_db REDIS_HOST=localhost REDIS_PORT=6379 MILVUS_HOST=localhost MILVUS_PORT=19530 MILVUS_DIM=1024 # 阿里云百炼配置(OpenAI 兼容模式) LLM_PROVIDER=openai LLM_MODEL=qwen-plus LLM_TEMPERATURE=0.7 LLM_MAX_TOKENS=2000 OPENAI_API_KEY=sk-your-bailian-api-key-here OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 EMBEDDING_PROVIDER=openai EMBEDDING_MODEL=text-embedding-v4 LOG_LEVEL=INFO EOF echo "" echo "百炼配置已创建" echo "请编辑 .env 并填入你的百炼 API Key" echo "获取 Key: https://dashscope.console.aliyun.com/" ;; 3) # 其他代理 cat > .env << 'EOF' # FinnewsHunter 环境配置 APP_NAME=FinnewsHunter DEBUG=True POSTGRES_USER=finnews POSTGRES_PASSWORD=finnews_dev_password POSTGRES_HOST=localhost POSTGRES_PORT=5432 POSTGRES_DB=finnews_db REDIS_HOST=localhost REDIS_PORT=6379 MILVUS_HOST=localhost MILVUS_PORT=19530 MILVUS_DIM=1536 # OpenAI 代理配置 LLM_PROVIDER=openai LLM_MODEL=gpt-3.5-turbo LLM_TEMPERATURE=0.7 LLM_MAX_TOKENS=2000 OPENAI_API_KEY=sk-your-proxy-api-key OPENAI_BASE_URL=https://your-proxy.com/v1 EMBEDDING_PROVIDER=openai EMBEDDING_MODEL=text-embedding-ada-002 LOG_LEVEL=INFO EOF echo "" echo "代理配置已创建" echo "请编辑 .env 并填入你的代理信息" ;; 4) # 手动配置 cp env.example .env echo "" echo "配置模板已复制" echo "请编辑 .env 并选择合适的配置方案" ;; *) echo "无效选项" exit 1 ;; esac echo "" read -p "是否现在编辑配置文件?(Y/n): " -n 1 -r echo if [[ ! $REPLY =~ ^[Nn]$ ]]; then ${EDITOR:-nano} .env fi echo "" echo "配置完成!运行 ./start.sh 启动服务" ================================================ FILE: backend/start.sh ================================================ #!/bin/bash # FinnewsHunter 启动脚本 set -e echo "===================================" echo " FinnewsHunter Backend Startup" echo "===================================" # 获取脚本所在目录(backend目录) SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" DEPLOY_DIR="$(cd "$SCRIPT_DIR/../deploy" && pwd)" # 1. 
启动 Docker Compose 服务 echo "" echo "[1/4] Starting Docker Compose services..." cd "$DEPLOY_DIR" docker-compose -f docker-compose.dev.yml up -d # 等待数据库启动 echo "" echo "[2/4] Waiting for databases to be ready..." sleep 10 # 2. 初始化数据库(首次运行) echo "" echo "[3/4] Initializing database..." cd "$SCRIPT_DIR" python init_db.py || echo "Database initialization skipped (may already exist)" # 3. 启动 FastAPI 应用 echo "" echo "[4/4] Starting FastAPI application..." echo "" echo "Server will start at: http://localhost:8000" echo "API Documentation: http://localhost:8000/docs" echo "" # 确保在 backend 目录下启动 cd "$SCRIPT_DIR" uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 ================================================ FILE: backend/start_celery.sh ================================================ #!/bin/bash # Celery 容器化重启脚本 # 用法: ./start_celery.sh [--restart|-r] [--force-recreate|-f] [--rebuild|-b] [--logs|-l] set -e # 解析命令行参数 AUTO_RESTART=false FORCE_RECREATE=false REBUILD_IMAGE=false SHOW_LOGS=false while [[ $# -gt 0 ]]; do case $1 in --restart|-r) AUTO_RESTART=true shift ;; --force-recreate|-f) FORCE_RECREATE=true AUTO_RESTART=true shift ;; --rebuild|-b) REBUILD_IMAGE=true FORCE_RECREATE=true AUTO_RESTART=true shift ;; --logs|-l) SHOW_LOGS=true shift ;; --help|-h) echo "用法: $0 [选项]" echo "" echo "选项:" echo " --restart, -r 自动重启容器(容器使用 python:3.11 基础镜像 + volumes 挂载)" echo " --force-recreate, -f 强制重建容器(会重新安装依赖,因为使用基础镜像)" echo " --rebuild, -b 重新构建镜像(构建的镜像不会被使用,仅用于清理未使用的镜像)" echo " --logs, -l 重启后自动显示日志" echo " --help, -h 显示帮助信息" echo "" echo "注意:" echo " - 当前容器使用 python:3.11 基础镜像 + volumes 挂载代码" echo " - 每次启动容器都会执行 pip install 安装依赖" echo " - --rebuild 选项会构建镜像,但构建的镜像不会被容器使用" echo "" echo "示例:" echo " $0 交互式重启容器" echo " $0 --restart 自动重启容器" echo " $0 -r -l 自动重启并显示日志" echo " $0 -f 强制重建容器(会重新安装依赖)" echo " $0 --rebuild 重新构建镜像(仅用于清理未使用的镜像)" exit 0 ;; *) echo "未知参数: $1" echo "使用 --help 查看帮助信息" exit 1 ;; esac done echo "============================================" echo " FinnewsHunter Celery 容器重启脚本" echo "============================================" echo "" # 获取脚本所在目录 SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" cd "$SCRIPT_DIR" # 检查 Docker 是否运行 if ! docker info > /dev/null 2>&1; then echo "Docker 未运行,请先启动 Docker" exit 1 fi # 检查 docker-compose 文件是否存在 COMPOSE_FILE="../deploy/docker-compose.dev.yml" if [ ! -f "$COMPOSE_FILE" ]; then echo "找不到 docker-compose 文件: $COMPOSE_FILE" exit 1 fi # 检查容器状态 echo "" echo "[1/4] 检查 Celery 容器状态..." WORKER_RUNNING=$(docker ps -q -f name=finnews_celery_worker) BEAT_RUNNING=$(docker ps -q -f name=finnews_celery_beat) if [ -n "$WORKER_RUNNING" ] || [ -n "$BEAT_RUNNING" ]; then echo "检测到 Celery 容器正在运行" echo " - Worker: $([ -n "$WORKER_RUNNING" ] && echo "运行中 ($WORKER_RUNNING)" || echo "未运行")" echo " - Beat: $([ -n "$BEAT_RUNNING" ] && echo "运行中 ($BEAT_RUNNING)" || echo "未运行")" if [ "$AUTO_RESTART" = false ]; then read -p "是否重启容器?(y/N): " -n 1 -r echo if [[ ! $REPLY =~ ^[Yy]$ ]]; then echo "已取消重启" exit 0 fi else echo "自动重启模式,无需确认" fi fi # 检查 Redis 是否运行 echo "" echo "[2/4] 检查 Redis 连接..." if docker exec finnews_redis redis-cli ping > /dev/null 2>&1; then echo "Redis 正常运行" else echo "Redis 未运行,请先启动 Docker Compose:" echo " cd ../deploy && docker-compose -f docker-compose.dev.yml up -d redis" exit 1 fi # 重启 Celery Worker 容器 echo "" cd ../deploy if [ "$REBUILD_IMAGE" = true ]; then echo "[3/5] 重新构建镜像(注意:构建的镜像不会被容器使用,仅用于清理未使用的镜像)..." docker-compose -f docker-compose.dev.yml build celery-worker celery-beat echo "[4/5] 强制重建 Celery Worker 容器(使用 python:3.11 基础镜像 + volumes 挂载)..." 
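# Note: the worker/beat containers run the plain python:3.11 base image with the
# source tree volume-mounted, so force-recreating them re-runs `pip install` on
# startup; the image built in the previous step is not actually used by these
# containers. Equivalent manual command (for reference, as listed in the help
# text above):
#   cd ../deploy && docker-compose -f docker-compose.dev.yml up -d --force-recreate celery-worker celery-beat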
docker-compose -f docker-compose.dev.yml up -d --force-recreate celery-worker elif [ "$FORCE_RECREATE" = true ]; then echo "[3/4] 强制重建 Celery Worker 容器(使用 python:3.11 基础镜像,会重新安装依赖)..." docker-compose -f docker-compose.dev.yml up -d --force-recreate celery-worker else echo "[3/4] 重启 Celery Worker 容器(使用 python:3.11 基础镜像 + volumes 挂载)..." docker-compose -f docker-compose.dev.yml restart celery-worker fi WORKER_CONTAINER_ID=$(docker ps -q -f name=finnews_celery_worker) echo "Worker 容器已重启 (Container ID: $WORKER_CONTAINER_ID)" # 等待 Worker 启动 sleep 3 # 重启 Celery Beat 容器 echo "" if [ "$REBUILD_IMAGE" = true ]; then echo "[5/5] 强制重建 Celery Beat 容器(使用 python:3.11 基础镜像 + volumes 挂载)..." docker-compose -f docker-compose.dev.yml up -d --force-recreate celery-beat elif [ "$FORCE_RECREATE" = true ]; then echo "[4/4] 强制重建 Celery Beat 容器(使用 python:3.11 基础镜像,会重新安装依赖)..." docker-compose -f docker-compose.dev.yml up -d --force-recreate celery-beat else echo "[4/4] 重启 Celery Beat 容器(使用 python:3.11 基础镜像 + volumes 挂载)..." docker-compose -f docker-compose.dev.yml restart celery-beat fi BEAT_CONTAINER_ID=$(docker ps -q -f name=finnews_celery_beat) echo "Beat 容器已重启 (Container ID: $BEAT_CONTAINER_ID)" cd "$SCRIPT_DIR" echo "" echo "============================================" echo " Celery 容器重启成功!" echo "============================================" echo "" echo "容器信息:" echo " - Worker Container ID: $WORKER_CONTAINER_ID" echo " - Beat Container ID: $BEAT_CONTAINER_ID" echo "" echo "查看日志命令:" echo " - Worker 日志: docker logs -f finnews_celery_worker" echo " - Beat 日志: docker logs -f finnews_celery_beat" echo " - 最近100行: docker logs --tail 100 finnews_celery_worker" echo "" echo "监控命令:" echo " - 查看任务列表: curl http://localhost:8000/api/v1/tasks/" echo " - 查看容器状态: docker ps | grep celery" echo "" echo "实时监控已启动,每1分钟自动爬取新闻" echo "" echo "说明:" echo " - 容器使用 python:3.11 基础镜像 + volumes 挂载代码" echo " - 每次启动容器都会执行 pip install 安装依赖" echo " - 构建的镜像(deploy-celery-worker/beat)不会被使用,可以删除释放空间" echo "" echo "停止服务:" echo " cd ../deploy && docker-compose -f docker-compose.dev.yml stop celery-worker celery-beat" echo "" echo "完全重启(重建容器,会重新安装依赖):" echo " cd ../deploy && docker-compose -f docker-compose.dev.yml up -d --force-recreate celery-worker celery-beat" echo "" echo "============================================" if [ "$SHOW_LOGS" = true ]; then echo "" echo "正在监控日志(按 Ctrl+C 退出)..." 
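# Follow the worker log, starting from the last 50 lines; Ctrl+C stops following
# the log without stopping the container itself.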
echo "" sleep 2 docker logs -f --tail 50 finnews_celery_worker fi ================================================ FILE: backend/tests/__init__.py ================================================ """FinnewsHunter Tests""" ================================================ FILE: backend/tests/check_milvus_data.py ================================================ #!/usr/bin/env python3 """ 检查 Milvus 向量存储中的数据 """ import sys import os import asyncio # 添加项目路径 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from app.storage.vector_storage import get_vector_storage from app.core.config import settings def main(): try: print("=" * 60) print("Milvus 向量存储信息") print("=" * 60) storage = get_vector_storage() stats = storage.get_stats() print(f"\n📊 集合统计信息:") print(f" 集合名称: {stats['collection_name']}") print(f" 向量维度: {stats['dim']}") num_entities = stats['num_entities'] if isinstance(num_entities, str): print(f" 存储的向量数量: {num_entities}") else: print(f" 存储的向量数量: {num_entities}") if num_entities == 0: print(f" ⚠️ 注意:如果显示为 0,可能是 flush 失败导致统计不准确") print(f" Milvus地址: {storage.host}:{storage.port}") # 查询一些示例数据 print(f"\n📝 查询示例数据:") try: # 使用 agenticx 的 query 方法获取数据 from agenticx.storage.vectordb_storages.base import VectorDBQuery # 创建一个零向量查询来获取所有数据(top_k 限制结果数) zero_vector = [0.0] * stats['dim'] query = VectorDBQuery(query_vector=zero_vector, top_k=10) # query 是同步方法,可以直接调用 results = storage.milvus_storage.query(query) if results: print(f" ✅ 找到 {len(results)} 条记录") if isinstance(stats['num_entities'], str) or stats['num_entities'] != len(results): print(f" ℹ️ 统计数量: {stats['num_entities']}") print() for i, result in enumerate(results[:5], 1): # 只显示前5条 payload = result.record.payload or {} news_id = payload.get('news_id', result.record.id) text = payload.get('text', '') text_preview = text[:100] + "..." if len(text) > 100 else text print(f" {i}. 新闻ID: {news_id}") print(f" 文本预览: {text_preview}") if len(results) > 5: print(f"\n ... 还有 {len(results) - 5} 条记录未显示") else: if stats['num_entities'] == 0: print(" ⚠️ 未找到数据,集合可能确实为空") print(" 提示: 向量数据会在新闻分析时自动生成并存储") else: print(f" ⚠️ 未找到数据,但统计显示有 {stats['num_entities']} 条记录") print(" 可能的原因:数据在缓冲区中,需要等待 Milvus 自动刷新") except Exception as e: print(f" ❌ 无法查询数据: {e}") import traceback traceback.print_exc() if stats['num_entities'] == 0: print("\n 提示: 如果这是首次运行,集合可能确实为空") print("\n" + "=" * 60) print("💡 提示:") print(" - 向量数据存储在 Milvus 数据库中") print(" - 可以通过 Milvus 客户端工具查看完整数据") print(" - 向量维度必须与 embedding 模型匹配") print("=" * 60) except Exception as e: print(f"\n❌ 错误: {e}") print("\n可能的原因:") print(" 1. Milvus 服务未启动") print(" 2. Milvus 连接配置错误") print(" 3. 
集合尚未创建") print("\n检查方法:") print(f" - 确认 Milvus 运行在 {settings.MILVUS_HOST}:{settings.MILVUS_PORT}") print(f" - 检查 .env 文件中的 MILVUS_* 配置") if __name__ == "__main__": main() ================================================ FILE: backend/tests/check_news_embedding_status.py ================================================ #!/usr/bin/env python3 """ 检查新闻的向量化状态 """ import sys import os import asyncio # 添加项目路径 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from sqlalchemy import select, func from app.core.database import get_db from app.models.news import News from app.models.analysis import Analysis async def main(): try: async for db in get_db(): # 统计总体情况 total_result = await db.execute(select(func.count(News.id))) total_news = total_result.scalar() or 0 embedded_result = await db.execute( select(func.count(News.id)).where(News.is_embedded == 1) ) embedded_count = embedded_result.scalar() or 0 analyzed_result = await db.execute( select(func.count(News.id)).where(News.sentiment_score.isnot(None)) ) analyzed_count = analyzed_result.scalar() or 0 # 查找已分析但未向量化的新闻 not_embedded_result = await db.execute( select(News.id, News.title, News.sentiment_score) .where( News.sentiment_score.isnot(None), News.is_embedded == 0 ) .order_by(News.id.desc()) .limit(10) ) not_embedded_news = not_embedded_result.all() print("=" * 60) print("新闻向量化状态统计") print("=" * 60) print(f"\n📊 总体统计:") print(f" 总新闻数: {total_news}") print(f" 已分析新闻: {analyzed_count}") print(f" 已向量化新闻: {embedded_count}") print(f" 已分析但未向量化: {analyzed_count - embedded_count}") if not_embedded_news: print(f"\n⚠️ 最近10条已分析但未向量化的新闻:") for news_id, title, sentiment_score in not_embedded_news: title_preview = title[:50] + "..." if len(title) > 50 else title print(f" - ID: {news_id}, 情感分数: {sentiment_score:.2f}") print(f" 标题: {title_preview}") else: print("\n✅ 所有已分析的新闻都已向量化") print("\n" + "=" * 60) print("💡 可能的原因:") print(" 1. Embedding API 超时(20秒超时)") print(" 2. Milvus 连接失败") print(" 3. Embedding 服务配置错误") print("\n🔧 解决方案:") print(" 1. 检查后端日志中的 embedding 错误") print(" 2. 确认 Milvus 服务正在运行") print(" 3. 检查 embedding API 配置(百炼/OpenAI)") print(" 4. 
可以手动重新向量化这些新闻") print("=" * 60) except Exception as e: print(f"\n❌ 错误: {e}") import traceback traceback.print_exc() if __name__ == "__main__": asyncio.run(main()) ================================================ FILE: backend/tests/financial/__init__.py ================================================ """Financial module tests""" ================================================ FILE: backend/tests/financial/test_smoke_openbb_models.py ================================================ """ 冒烟测试: Standard Models (P0-1, P0-2) 验证: - NewsQueryParams, NewsData 模型可正常实例化 - StockQueryParams, StockPriceData 模型可正常实例化 - 字段验证逻辑正确 - to_legacy_dict 兼容方法正常工作 运行: pytest -q -k "smoke_openbb_models" """ import pytest from datetime import datetime class TestNewsModels: """测试新闻相关模型""" def test_news_query_params_basic(self): """测试 NewsQueryParams 基本实例化""" from app.financial.models.news import NewsQueryParams # 默认参数 params = NewsQueryParams() assert params.limit == 50 assert params.keywords is None assert params.stock_codes is None # 自定义参数 params = NewsQueryParams( keywords=["茅台", "白酒"], stock_codes=["600519"], limit=20 ) assert params.keywords == ["茅台", "白酒"] assert params.stock_codes == ["600519"] assert params.limit == 20 def test_news_query_params_validation(self): """测试 NewsQueryParams 字段验证""" from app.financial.models.news import NewsQueryParams from pydantic import ValidationError # limit 边界测试 params = NewsQueryParams(limit=1) assert params.limit == 1 params = NewsQueryParams(limit=500) assert params.limit == 500 # limit 超出范围应报错 with pytest.raises(ValidationError): NewsQueryParams(limit=0) with pytest.raises(ValidationError): NewsQueryParams(limit=501) def test_news_data_basic(self): """测试 NewsData 基本实例化""" from app.financial.models.news import NewsData, NewsSentiment news = NewsData( id="test123", title="测试新闻标题", content="这是测试新闻的正文内容...", source="sina", source_url="https://finance.sina.com.cn/test", publish_time=datetime(2024, 1, 1, 10, 30) ) assert news.id == "test123" assert news.title == "测试新闻标题" assert news.source == "sina" assert news.sentiment is None # 可选字段默认 None assert news.stock_codes == [] # 默认空列表 def test_news_data_with_sentiment(self): """测试 NewsData 带情感标签""" from app.financial.models.news import NewsData, NewsSentiment news = NewsData( id="test456", title="利好消息", content="公司业绩超预期...", source="sina", source_url="https://example.com", publish_time=datetime.now(), sentiment=NewsSentiment.POSITIVE, sentiment_score=0.85 ) assert news.sentiment == NewsSentiment.POSITIVE assert news.sentiment_score == 0.85 def test_news_data_generate_id(self): """测试 NewsData.generate_id 方法""" from app.financial.models.news import NewsData url1 = "https://finance.sina.com.cn/news/123" url2 = "https://finance.sina.com.cn/news/456" id1 = NewsData.generate_id(url1) id2 = NewsData.generate_id(url2) # 相同 URL 生成相同 ID assert id1 == NewsData.generate_id(url1) # 不同 URL 生成不同 ID assert id1 != id2 # ID 长度为 16 assert len(id1) == 16 def test_news_data_to_legacy_dict(self): """测试 NewsData.to_legacy_dict 兼容方法""" from app.financial.models.news import NewsData news = NewsData( id="test789", title="测试标题", content="测试内容", source="sina", source_url="https://example.com/news", publish_time=datetime(2024, 6, 15, 14, 30), author="记者", stock_codes=["SH600519"] ) legacy = news.to_legacy_dict() # 验证字段映射 assert legacy["title"] == "测试标题" assert legacy["url"] == "https://example.com/news" # source_url → url assert legacy["source"] == "sina" assert legacy["author"] == "记者" assert "SH600519" in legacy["stock_codes"] class TestStockModels: 
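    # These tests mirror the news-model tests above for the stock-side standard
    # models: StockQueryParams defaults (daily interval, qfq adjust, limit=90),
    # StockPriceData construction and its to_legacy_dict() mapping, and the
    # KlineInterval / AdjustType enums. Minimal sketch, values taken from the
    # assertions below (illustrative only):
    #     params = StockQueryParams(symbol="600519")  # interval=DAILY, adjust=QFQ, limit=90
    #     price = StockPriceData(symbol="600519", date=datetime(2024, 6, 15),
    #                            open=1500.0, high=1520.0, low=1490.0,
    #                            close=1510.0, volume=1000000)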
"""测试股票相关模型""" def test_stock_query_params_basic(self): """测试 StockQueryParams 基本实例化""" from app.financial.models.stock import ( StockQueryParams, KlineInterval, AdjustType ) # 最小参数 params = StockQueryParams(symbol="600519") assert params.symbol == "600519" assert params.interval == KlineInterval.DAILY assert params.adjust == AdjustType.QFQ assert params.limit == 90 # 自定义参数 params = StockQueryParams( symbol="SH600519", interval=KlineInterval.MIN_5, adjust=AdjustType.HFQ, limit=30 ) assert params.interval == KlineInterval.MIN_5 assert params.adjust == AdjustType.HFQ def test_stock_price_data_basic(self): """测试 StockPriceData 基本实例化""" from app.financial.models.stock import StockPriceData price = StockPriceData( symbol="600519", date=datetime(2024, 6, 15), open=1500.0, high=1520.0, low=1490.0, close=1510.0, volume=1000000 ) assert price.symbol == "600519" assert price.close == 1510.0 assert price.turnover is None # 可选字段 def test_stock_price_data_to_legacy_dict(self): """测试 StockPriceData.to_legacy_dict 兼容方法""" from app.financial.models.stock import StockPriceData price = StockPriceData( symbol="600519", date=datetime(2024, 6, 15, 10, 0, 0), open=1500.0, high=1520.0, low=1490.0, close=1510.0, volume=1000000, change_percent=0.67 ) legacy = price.to_legacy_dict() # 验证字段 assert legacy["date"] == "2024-06-15" assert legacy["close"] == 1510.0 assert legacy["change_percent"] == 0.67 assert "timestamp" in legacy # 应包含毫秒时间戳 def test_kline_interval_enum(self): """测试 KlineInterval 枚举""" from app.financial.models.stock import KlineInterval assert KlineInterval.MIN_1.value == "1m" assert KlineInterval.DAILY.value == "1d" assert KlineInterval("1d") == KlineInterval.DAILY def test_adjust_type_enum(self): """测试 AdjustType 枚举""" from app.financial.models.stock import AdjustType assert AdjustType.QFQ.value == "qfq" assert AdjustType.HFQ.value == "hfq" assert AdjustType("none") == AdjustType.NONE ================================================ FILE: backend/tests/financial/test_smoke_openbb_provider.py ================================================ """ 冒烟测试: Provider & Registry (P0-3, P0-4) 验证: - BaseFetcher 抽象类可被正确继承 - BaseProvider 抽象类可被正确继承 - ProviderRegistry 注册/获取/降级逻辑 - SinaProvider 正确注册 运行: pytest -q -k "smoke_openbb_provider" """ import pytest from typing import Dict, Any, List, Type from datetime import datetime class TestBaseFetcherAbstraction: """测试 BaseFetcher 抽象""" def test_fetcher_subclass_implementation(self): """测试 Fetcher 子类实现""" from app.financial.providers.base import BaseFetcher from app.financial.models.news import NewsQueryParams, NewsData class MockNewsFetcher(BaseFetcher[NewsQueryParams, NewsData]): query_model = NewsQueryParams data_model = NewsData def transform_query(self, params: NewsQueryParams) -> Dict[str, Any]: return {"limit": params.limit, "keywords": params.keywords} async def extract_data(self, query: Dict[str, Any]) -> List[Dict]: return [ {"title": "Test News", "content": "Content", "url": "http://test.com"} ] def transform_data(self, raw_data: List[Dict], query: NewsQueryParams) -> List[NewsData]: return [ NewsData( id=f"mock_{i}", title=item["title"], content=item["content"], source="mock", source_url=item["url"], publish_time=datetime.now() ) for i, item in enumerate(raw_data) ] fetcher = MockNewsFetcher() # 测试 transform_query params = NewsQueryParams(limit=10, keywords=["test"]) query = fetcher.transform_query(params) assert query["limit"] == 10 assert query["keywords"] == ["test"] @pytest.mark.asyncio async def test_fetcher_fetch_pipeline(self): """测试 Fetcher 完整 TET 
Pipeline""" from app.financial.providers.base import BaseFetcher from app.financial.models.news import NewsQueryParams, NewsData class MockFetcher(BaseFetcher[NewsQueryParams, NewsData]): query_model = NewsQueryParams data_model = NewsData def transform_query(self, params): return {"count": params.limit} async def extract_data(self, query): return [{"title": f"News {i}"} for i in range(query["count"])] def transform_data(self, raw_data, query): return [ NewsData( id=f"id_{i}", title=item["title"], content="content", source="mock", source_url="http://mock.com", publish_time=datetime.now() ) for i, item in enumerate(raw_data) ] fetcher = MockFetcher() params = NewsQueryParams(limit=5) results = await fetcher.fetch(params) assert len(results) == 5 assert all(isinstance(r, NewsData) for r in results) class TestBaseProviderAbstraction: """测试 BaseProvider 抽象""" def test_provider_subclass_implementation(self): """测试 Provider 子类实现""" from app.financial.providers.base import BaseProvider, BaseFetcher, ProviderInfo from app.financial.models.news import NewsQueryParams, NewsData class MockFetcher(BaseFetcher[NewsQueryParams, NewsData]): query_model = NewsQueryParams data_model = NewsData def transform_query(self, params): return {} async def extract_data(self, query): return [] def transform_data(self, raw_data, query): return [] class MockProvider(BaseProvider): @property def info(self) -> ProviderInfo: return ProviderInfo( name="mock", display_name="Mock Provider", description="For testing", priority=99 ) @property def fetchers(self) -> Dict[str, Type[BaseFetcher]]: return {"news": MockFetcher} provider = MockProvider() assert provider.info.name == "mock" assert provider.supports("news") is True assert provider.supports("stock_price") is False fetcher = provider.get_fetcher("news") assert fetcher is not None assert isinstance(fetcher, MockFetcher) class TestProviderRegistry: """测试 ProviderRegistry""" def test_registry_singleton(self): """测试 Registry 单例模式""" from app.financial.registry import ProviderRegistry r1 = ProviderRegistry() r2 = ProviderRegistry() assert r1 is r2 def test_registry_register_and_list(self): """测试注册和列出 Provider""" from app.financial.registry import reset_registry from app.financial.providers.base import BaseProvider, ProviderInfo, BaseFetcher from typing import Dict, Type registry = reset_registry() class MockProvider1(BaseProvider): @property def info(self): return ProviderInfo(name="p1", display_name="P1", description="", priority=2) @property def fetchers(self): return {} class MockProvider2(BaseProvider): @property def info(self): return ProviderInfo(name="p2", display_name="P2", description="", priority=1) @property def fetchers(self): return {} registry.register(MockProvider1()) registry.register(MockProvider2()) providers = registry.list_providers() assert "p1" in providers assert "p2" in providers # p2 优先级更高,应该在前面 assert providers.index("p2") < providers.index("p1") def test_registry_get_fetcher_auto_fallback(self): """测试获取 Fetcher 自动降级""" from app.financial.registry import reset_registry, FetcherNotFoundError from app.financial.providers.base import BaseProvider, ProviderInfo, BaseFetcher from app.financial.models.news import NewsQueryParams, NewsData from typing import Dict, Type from datetime import datetime registry = reset_registry() class MockFetcher(BaseFetcher[NewsQueryParams, NewsData]): query_model = NewsQueryParams data_model = NewsData def transform_query(self, params): return {} async def extract_data(self, query): return [] def transform_data(self, 
raw_data, query): return [] class ProviderA(BaseProvider): @property def info(self): return ProviderInfo(name="a", display_name="A", description="", priority=1) @property def fetchers(self): return {"news": MockFetcher} class ProviderB(BaseProvider): @property def info(self): return ProviderInfo(name="b", display_name="B", description="", priority=2) @property def fetchers(self): return {"news": MockFetcher, "stock": MockFetcher} registry.register(ProviderA()) registry.register(ProviderB()) # 获取 news:应该返回 ProviderA 的 (优先级更高) fetcher = registry.get_fetcher("news") assert fetcher is not None # 获取 stock:只有 ProviderB 支持 fetcher = registry.get_fetcher("stock") assert fetcher is not None # 获取不存在的类型 with pytest.raises(FetcherNotFoundError): registry.get_fetcher("nonexistent") def test_registry_get_fetcher_by_name(self): """测试指定 Provider 名称获取 Fetcher""" from app.financial.registry import reset_registry, ProviderNotFoundError from app.financial.providers.base import BaseProvider, ProviderInfo, BaseFetcher from app.financial.models.news import NewsQueryParams, NewsData registry = reset_registry() class MockFetcher(BaseFetcher[NewsQueryParams, NewsData]): query_model = NewsQueryParams data_model = NewsData def transform_query(self, params): return {} async def extract_data(self, query): return [] def transform_data(self, raw_data, query): return [] class MyProvider(BaseProvider): @property def info(self): return ProviderInfo(name="my", display_name="My", description="") @property def fetchers(self): return {"news": MockFetcher} registry.register(MyProvider()) # 指定存在的 Provider fetcher = registry.get_fetcher("news", provider="my") assert fetcher is not None # 指定不存在的 Provider with pytest.raises(ProviderNotFoundError): registry.get_fetcher("news", provider="nonexistent") class TestSinaProvider: """测试 SinaProvider""" def test_sina_provider_info(self): """测试 SinaProvider 元信息""" from app.financial.providers.sina import SinaProvider provider = SinaProvider() assert provider.info.name == "sina" assert provider.info.display_name == "新浪财经" assert provider.supports("news") is True def test_sina_provider_get_news_fetcher(self): """测试获取 SinaNewsFetcher""" from app.financial.providers.sina import SinaProvider from app.financial.providers.sina.fetchers.news import SinaNewsFetcher provider = SinaProvider() fetcher = provider.get_fetcher("news") assert fetcher is not None assert isinstance(fetcher, SinaNewsFetcher) def test_sina_news_fetcher_transform_query(self): """测试 SinaNewsFetcher.transform_query""" from app.financial.providers.sina.fetchers.news import SinaNewsFetcher from app.financial.models.news import NewsQueryParams fetcher = SinaNewsFetcher() # 无股票代码 params = NewsQueryParams(limit=10) query = fetcher.transform_query(params) assert query["limit"] == 10 assert "base_url" in query # 有股票代码 params = NewsQueryParams(stock_codes=["600519"], limit=20) query = fetcher.transform_query(params) assert "stock_urls" in query assert len(query["stock_urls"]) == 1 assert "sh600519" in query["stock_urls"][0].lower() ================================================ FILE: backend/tests/financial/test_smoke_openbb_tools.py ================================================ """ 冒烟测试: Financial Tools (P1-2) 验证: - FinancialNewsTool 可正常实例化 - Tool 在无 Provider 时返回错误而非崩溃 - Tool 正确调用 Registry 运行: pytest -q -k "smoke_openbb_tools" """ import pytest from unittest.mock import patch, AsyncMock, MagicMock from datetime import datetime class TestFinancialNewsTool: """测试 FinancialNewsTool""" def test_tool_instantiation(self): """测试工具实例化""" 
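        # FinancialNewsTool exposes the same execute()/aexecute() protocol as the
        # other AgenticX tools in this repo and degrades gracefully when no
        # provider is registered (see the tests below). Hedged usage sketch
        # (illustrative only):
        #     tool = FinancialNewsTool()
        #     result = await tool.aexecute(limit=10)
        #     # -> {"success": True, "count": ..., "data": [...]} on success,
        #     #    or {"success": False, "error": ...} when no provider is available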
from app.financial.tools import FinancialNewsTool tool = FinancialNewsTool() assert tool.name == "financial_news" assert "金融新闻" in tool.description or "news" in tool.description.lower() def test_tool_has_required_methods(self): """测试工具具有必要方法""" from app.financial.tools import FinancialNewsTool tool = FinancialNewsTool() assert hasattr(tool, "execute") assert hasattr(tool, "aexecute") assert callable(tool.execute) assert callable(tool.aexecute) @pytest.mark.asyncio async def test_tool_returns_error_when_no_provider(self): """测试无 Provider 时返回错误""" from app.financial.tools import FinancialNewsTool from app.financial.registry import reset_registry # 清空 Registry reset_registry() tool = FinancialNewsTool() result = await tool.aexecute(limit=10) # 应返回错误而非崩溃 assert result["success"] is False assert "error" in result @pytest.mark.asyncio async def test_tool_with_mocked_fetcher(self): """测试工具与 Mock Fetcher 集成""" from app.financial.tools import FinancialNewsTool from app.financial.registry import reset_registry, get_registry from app.financial.providers.base import BaseProvider, ProviderInfo, BaseFetcher from app.financial.models.news import NewsQueryParams, NewsData registry = reset_registry() # 创建 Mock Fetcher class MockFetcher(BaseFetcher[NewsQueryParams, NewsData]): query_model = NewsQueryParams data_model = NewsData def transform_query(self, params): return {"limit": params.limit} async def extract_data(self, query): return [ {"title": "Mock News 1", "content": "Content 1", "url": "http://mock1.com"}, {"title": "Mock News 2", "content": "Content 2", "url": "http://mock2.com"}, ] def transform_data(self, raw_data, query): return [ NewsData( id=f"mock_{i}", title=item["title"], content=item["content"], source="mock", source_url=item["url"], publish_time=datetime.now() ) for i, item in enumerate(raw_data) ] class MockProvider(BaseProvider): @property def info(self): return ProviderInfo(name="mock", display_name="Mock", description="") @property def fetchers(self): return {"news": MockFetcher} registry.register(MockProvider()) tool = FinancialNewsTool() result = await tool.aexecute(limit=10) assert result["success"] is True assert result["count"] == 2 assert len(result["data"]) == 2 assert result["data"][0]["title"] == "Mock News 1" class TestStockPriceTool: """测试 StockPriceTool""" def test_tool_instantiation(self): """测试工具实例化""" from app.financial.tools import StockPriceTool tool = StockPriceTool() assert tool.name == "stock_price" assert "K线" in tool.description or "price" in tool.description.lower() @pytest.mark.asyncio async def test_tool_returns_error_for_invalid_interval(self): """测试无效参数时返回错误""" from app.financial.tools import StockPriceTool tool = StockPriceTool() result = await tool.aexecute(symbol="600519", interval="invalid_interval") assert result["success"] is False assert "error" in result @pytest.mark.asyncio async def test_tool_returns_error_when_no_provider(self): """测试无 Provider 时返回错误""" from app.financial.tools import StockPriceTool from app.financial.registry import reset_registry reset_registry() tool = StockPriceTool() result = await tool.aexecute(symbol="600519") assert result["success"] is False assert "error" in result class TestSetupDefaultProviders: """测试默认 Provider 设置""" def test_setup_registers_sina(self): """测试 setup_default_providers 注册 SinaProvider""" from app.financial.registry import reset_registry, get_registry from app.financial.tools import setup_default_providers registry = reset_registry() assert "sina" not in registry.list_providers() setup_default_providers() 
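        # setup_default_providers() registers SinaProvider on the shared singleton
        # registry, so "sina" should now appear in list_providers(); the next test
        # verifies that repeated calls stay idempotent (no duplicate registration).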
assert "sina" in registry.list_providers() def test_setup_idempotent(self): """测试 setup_default_providers 幂等性""" from app.financial.registry import reset_registry, get_registry from app.financial.tools import setup_default_providers reset_registry() # 多次调用不应报错 setup_default_providers() setup_default_providers() setup_default_providers() registry = get_registry() assert registry.list_providers().count("sina") == 1 ================================================ FILE: backend/tests/manual_vectorize.py ================================================ #!/usr/bin/env python3 """ 手动向量化新闻(用于修复未向量化的新闻) """ import sys import os import asyncio import logging # 添加项目路径 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) # 先加载环境变量(避免循环导入) from dotenv import load_dotenv from pathlib import Path env_path = Path(__file__).parent / ".env" load_dotenv(env_path) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) async def vectorize_news_manually(news_id: int): """手动向量化单个新闻""" # 直接使用 SQLAlchemy 创建连接,避免循环导入 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy import text from starlette.concurrency import run_in_threadpool # 从环境变量构建数据库 URL POSTGRES_USER = os.getenv("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "postgres") POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost") POSTGRES_PORT = os.getenv("POSTGRES_PORT", "5432") POSTGRES_DB = os.getenv("POSTGRES_DB", "finnews_db") DATABASE_URL = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" # 创建引擎和会话 engine = create_async_engine(DATABASE_URL, echo=False) AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) try: # 使用原始 SQL 查询,避免导入模型 async with AsyncSessionLocal() as db: # 查询新闻数据 result = await db.execute( text("SELECT id, title, content, is_embedded FROM news WHERE id = :news_id"), {"news_id": news_id} ) row = result.first() if not row: print(f"❌ 新闻 {news_id} 不存在") return False news_id_db, title, content, is_embedded = row if is_embedded == 1: print(f"ℹ️ 新闻 {news_id} 已经向量化过了") return True print(f"🔄 开始向量化新闻 {news_id}: {title[:50]}...") # 获取服务(这些服务不依赖数据库连接) from app.services.embedding_service import get_embedding_service from app.storage.vector_storage import get_vector_storage embedding_service = get_embedding_service() vector_storage = get_vector_storage() # 组合文本 text_to_embed = f"{title}\n{content[:1000]}" # 生成向量(增加超时时间到60秒) print(" 📡 调用 embedding API...") embedding = await asyncio.wait_for( embedding_service.aembed_text(text_to_embed), timeout=60.0 # 增加到60秒 ) print(f" ✅ 向量生成成功,维度: {len(embedding)}") # 存储到 Milvus(设置超时,避免卡住) print(" 💾 存储到 Milvus...") try: await asyncio.wait_for( run_in_threadpool( vector_storage.store_embedding, news_id=news_id, embedding=embedding, text=text_to_embed ), timeout=30.0 # 30秒超时 ) print(" ✅ 存储成功") except asyncio.TimeoutError: print(" ⚠️ 存储超时(30秒),但数据可能已插入") # 即使超时,数据可能已经插入,只是flush还没完成 # 更新数据库标志 await db.execute( text("UPDATE news SET is_embedded = 1 WHERE id = :news_id"), {"news_id": news_id} ) await db.commit() print(f" ✅ 更新数据库标志成功") print(f"✅ 新闻 {news_id} 向量化完成!") return True except asyncio.TimeoutError: print(f"❌ 新闻 {news_id} 向量化超时(60秒)") return False except Exception as e: print(f"❌ 新闻 {news_id} 向量化失败: {e}") import traceback traceback.print_exc() return False finally: await engine.dispose() async def vectorize_all_pending(): """向量化所有未向量化但已分析的新闻""" # 直接使用 SQLAlchemy 创建连接,避免循环导入 from sqlalchemy.ext.asyncio import 
create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy import text # 从环境变量构建数据库 URL POSTGRES_USER = os.getenv("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "postgres") POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost") POSTGRES_PORT = os.getenv("POSTGRES_PORT", "5432") POSTGRES_DB = os.getenv("POSTGRES_DB", "finnews_db") DATABASE_URL = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" # 创建引擎和会话 engine = create_async_engine(DATABASE_URL, echo=False) AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) try: print("🔍 正在查找需要向量化的新闻...") async with AsyncSessionLocal() as db: # 使用原始 SQL 查询,避免导入模型 result = await db.execute( text(""" SELECT id, title FROM news WHERE sentiment_score IS NOT NULL AND is_embedded = 0 ORDER BY id DESC """) ) pending_news = result.all() print(f"📊 查询完成,找到 {len(pending_news) if pending_news else 0} 条记录") if not pending_news: print("✅ 没有需要向量化的新闻") return print(f"📊 找到 {len(pending_news)} 条需要向量化的新闻") print("=" * 60) success_count = 0 failed_count = 0 # 使用单个处理方式,但添加了超时保护 for news_id, title in pending_news: print(f"\n处理新闻 {news_id}...") if await vectorize_news_manually(news_id): success_count += 1 else: failed_count += 1 print("\n" + "=" * 60) print(f"📊 向量化完成统计:") print(f" 成功: {success_count}") print(f" 失败: {failed_count}") print("=" * 60) finally: await engine.dispose() async def main_async(): import sys print("🚀 脚本开始执行...") if len(sys.argv) > 1: try: # 向量化指定的新闻ID news_id = int(sys.argv[1]) print(f"📌 向量化指定的新闻: {news_id}") await vectorize_news_manually(news_id) except ValueError: # 如果不是数字,可能是 --no-wait 参数 if sys.argv[1] == "--no-wait": print("📌 向量化所有未向量化的新闻(跳过等待)") await vectorize_all_pending() else: print(f"❌ 无效的参数: {sys.argv[1]}") print("用法: python manual_vectorize.py [news_id|--no-wait]") else: # 向量化所有未向量化的新闻 print("⚠️ 这将向量化所有已分析但未向量化的新闻") print(" 按 Ctrl+C 取消,或等待5秒后继续...") print(" (使用 --no-wait 参数可跳过等待)") try: await asyncio.sleep(5) except KeyboardInterrupt: print("\n已取消") sys.exit(0) await vectorize_all_pending() print("✅ 脚本执行完成") if __name__ == "__main__": asyncio.run(main_async()) ================================================ FILE: backend/tests/test_alpha_mining/__init__.py ================================================ """Alpha Mining 测试模块""" ================================================ FILE: backend/tests/test_alpha_mining/test_integration_p2.py ================================================ """ P2 集成测试 - Alpha Mining 完整集成 测试覆盖: - F18: QuantitativeAgent 集成 - F19: REST API 端点 - 完整工作流测试 """ import pytest import sys from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import asyncio # 添加项目路径 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) # ============================================================================ # F18: QuantitativeAgent 集成测试 # ============================================================================ class TestQuantitativeAgent: """量化分析智能体测试""" def test_agent_import(self): """测试 Agent 可导入""" from app.agents.quantitative_agent import QuantitativeAgent, create_quantitative_agent assert QuantitativeAgent is not None assert create_quantitative_agent is not None def test_agent_init_without_llm(self): """测试不使用 LLM 初始化""" from app.agents.quantitative_agent import QuantitativeAgent agent = QuantitativeAgent( llm_provider=None, enable_alpha_mining=True ) assert agent.enable_alpha_mining is True assert agent._alpha_mining_initialized is 
False def test_agent_lazy_init(self): """测试延迟初始化""" from app.agents.quantitative_agent import QuantitativeAgent agent = QuantitativeAgent(enable_alpha_mining=True) # 初始时未初始化 assert agent._generator is None assert agent._vm is None # 调用 _init_alpha_mining agent._init_alpha_mining() # 现在应该已初始化 assert agent._alpha_mining_initialized is True assert agent._generator is not None assert agent._vm is not None @pytest.mark.asyncio async def test_agent_mine_factors(self): """测试因子挖掘功能""" from app.agents.quantitative_agent import QuantitativeAgent agent = QuantitativeAgent(enable_alpha_mining=True) result = await agent._mine_factors( stock_code="000001", stock_name="测试股票", market_data=None, sentiment_data=None ) assert "factors" in result assert "stats" in result assert isinstance(result["factors"], list) @pytest.mark.asyncio async def test_agent_full_analysis(self): """测试完整分析流程(无 LLM)""" from app.agents.quantitative_agent import QuantitativeAgent agent = QuantitativeAgent( llm_provider=None, enable_alpha_mining=True ) result = await agent.analyze( stock_code="000001", stock_name="平安银行", market_data=None, sentiment_data=None, context="" ) assert result["success"] is True assert result["stock_code"] == "000001" assert "factors_discovered" in result @pytest.mark.asyncio async def test_agent_with_mock_llm(self): """测试使用 Mock LLM""" from app.agents.quantitative_agent import QuantitativeAgent # 创建 Mock LLM mock_llm = AsyncMock() mock_llm.chat = AsyncMock(return_value='{"trend": "上涨", "confidence": 0.7}') agent = QuantitativeAgent( llm_provider=mock_llm, enable_alpha_mining=True ) # 准备模拟数据 import torch market_data = { "close": torch.randn(100).abs() * 100 + 50, "volume": torch.randn(100).abs() * 1e6 } result = await agent.analyze( stock_code="000001", stock_name="平安银行", market_data=market_data, context="测试上下文" ) assert result["success"] is True assert len(result["factors_discovered"]) >= 0 def test_agent_evaluate_factor(self): """测试因子评估""" from app.agents.quantitative_agent import QuantitativeAgent agent = QuantitativeAgent(enable_alpha_mining=True) # 同步包装异步调用 loop = asyncio.get_event_loop() result = loop.run_until_complete( agent.evaluate_factor("ADD RET VOL") ) # 可能成功或失败,取决于公式解析 assert "success" in result def test_agent_get_best_factors(self): """测试获取最优因子""" from app.agents.quantitative_agent import QuantitativeAgent agent = QuantitativeAgent(enable_alpha_mining=True) # 手动添加一些因子 agent.discovered_factors = [ {"formula_str": "ADD(RET, VOL)", "sortino": 1.5}, {"formula_str": "MUL(RET, MA5(VOL))", "sortino": 0.8}, {"formula_str": "SUB(RET, DELTA1(VOL))", "sortino": 2.0}, ] best = agent.get_best_factors(top_k=2) assert len(best) == 2 assert best[0]["sortino"] == 2.0 # 最高分在前 # ============================================================================ # F19: REST API 测试 # ============================================================================ class TestAlphaMiningAPI: """Alpha Mining REST API 测试""" def test_api_module_import(self): """测试 API 模块可导入""" from app.api.v1.alpha_mining import router assert router is not None assert router.prefix == "/alpha-mining" def test_api_routes_exist(self): """测试 API 路由存在""" from app.api.v1.alpha_mining import router routes = [r.path for r in router.routes] assert "/mine" in routes assert "/evaluate" in routes assert "/generate" in routes assert "/factors" in routes assert "/status/{task_id}" in routes assert "/operators" in routes @pytest.fixture def test_client(self): """创建测试客户端""" try: from fastapi.testclient import TestClient from app.main import app return 
TestClient(app) except ImportError: pytest.skip("FastAPI test client not available") def test_get_operators(self, test_client): """测试获取操作符列表""" if test_client is None: pytest.skip("Test client not available") response = test_client.get("/api/v1/alpha-mining/operators") assert response.status_code == 200 data = response.json() assert data["success"] is True assert "features" in data assert "operators" in data def test_get_factors_empty(self, test_client): """测试获取因子列表(空)""" if test_client is None: pytest.skip("Test client not available") response = test_client.get("/api/v1/alpha-mining/factors") assert response.status_code == 200 data = response.json() assert data["success"] is True assert "factors" in data def test_evaluate_factor(self, test_client): """测试因子评估端点""" if test_client is None: pytest.skip("Test client not available") response = test_client.post( "/api/v1/alpha-mining/evaluate", json={"formula": "RET"} ) assert response.status_code == 200 data = response.json() assert "success" in data def test_generate_factors(self, test_client): """测试因子生成端点""" if test_client is None: pytest.skip("Test client not available") response = test_client.post( "/api/v1/alpha-mining/generate", json={"batch_size": 5, "max_len": 6} ) assert response.status_code == 200 data = response.json() assert data["success"] is True assert "factors" in data # ============================================================================ # 完整工作流测试 # ============================================================================ class TestFullWorkflow: """完整工作流测试""" @pytest.mark.asyncio async def test_end_to_end_factor_discovery(self): """端到端因子发现流程""" import torch # 1. 准备数据 from app.alpha_mining import ( AlphaMiningConfig, FactorVocab, FactorVM, AlphaGenerator, AlphaTrainer, FactorEvaluator, MarketFeatureBuilder, SentimentFeatureBuilder, generate_mock_data ) # 2. 初始化组件 config = AlphaMiningConfig( d_model=32, num_layers=1, batch_size=8, max_seq_len=6 ) vocab = FactorVocab() vm = FactorVM(vocab=vocab) generator = AlphaGenerator(vocab=vocab, config=config) evaluator = FactorEvaluator(config=config) # 3. 生成模拟数据 features, returns = generate_mock_data( num_samples=30, num_features=6, time_steps=100, seed=42 ) # 4. 创建训练器并训练 trainer = AlphaTrainer( generator=generator, vocab=vocab, config=config ) result = trainer.train( features=features, returns=returns, num_steps=5, # 少量步数用于测试 progress_bar=False ) assert result["total_steps"] == 5 assert "best_score" in result # 5. 
验证最优因子 if result["best_formula"]: factor = vm.execute(result["best_formula"], features) assert factor is not None or factor is None # 可能无效 if factor is not None: metrics = evaluator.evaluate(factor, returns) assert "sortino_ratio" in metrics print("\n✅ End-to-end factor discovery test passed!") @pytest.mark.asyncio async def test_quantitative_agent_workflow(self): """量化智能体工作流测试""" from app.agents.quantitative_agent import QuantitativeAgent import torch # 创建智能体 agent = QuantitativeAgent(enable_alpha_mining=True) # 准备数据 market_data = { "close": torch.randn(252).abs() * 100 + 50, "volume": torch.randn(252).abs() * 1e6 } sentiment_data = { "sentiment": torch.randn(252).tolist(), "news_count": torch.abs(torch.randn(252)).tolist() } # 执行分析 result = await agent.analyze( stock_code="600000", stock_name="浦发银行", market_data=market_data, sentiment_data=sentiment_data, context="银行股分析" ) assert result["success"] is True assert result["stock_code"] == "600000" assert "factors_discovered" in result print("\n✅ QuantitativeAgent workflow test passed!") print(f" - Factors discovered: {len(result['factors_discovered'])}") def test_api_and_agent_integration(self): """API 和 Agent 集成测试""" from app.agents.quantitative_agent import create_quantitative_agent # 创建智能体 agent = create_quantitative_agent(enable_alpha_mining=True) # 验证组件 agent._init_alpha_mining() assert agent._generator is not None assert agent._vm is not None assert agent._evaluator is not None # 验证因子生成 formulas, _ = agent._generator.generate(batch_size=3, max_len=5) assert len(formulas) == 3 # 验证因子执行 from app.alpha_mining import generate_mock_data features, returns = generate_mock_data(num_samples=10, time_steps=50) valid_count = 0 for formula in formulas: factor = agent._vm.execute(formula, features) if factor is not None: valid_count += 1 print(f"\n✅ API-Agent integration test passed!") print(f" - Generated: {len(formulas)}, Valid: {valid_count}") # ============================================================================ # 性能测试 # ============================================================================ class TestPerformance: """性能测试""" def test_generator_speed(self): """测试生成器速度""" import time from app.alpha_mining import AlphaGenerator, AlphaMiningConfig config = AlphaMiningConfig(d_model=64, num_layers=2) generator = AlphaGenerator(config=config) # 预热 generator.generate(batch_size=10, max_len=8) # 计时 start = time.time() for _ in range(10): generator.generate(batch_size=100, max_len=8) elapsed = time.time() - start avg_time = elapsed / 10 print(f"\n📊 Generator speed: {avg_time*1000:.2f}ms per batch (100 factors)") assert avg_time < 5.0 # 应该在 5 秒内完成 def test_vm_execution_speed(self): """测试 VM 执行速度""" import time import torch from app.alpha_mining import FactorVM, FactorVocab, generate_mock_data vm = FactorVM() vocab = FactorVocab() features, _ = generate_mock_data(num_samples=100, time_steps=252) # 创建测试公式 formulas = [ [0], # RET [0, 1, vocab.name_to_token("ADD")], # ADD(RET, VOL) [0, vocab.name_to_token("MA5")], # MA5(RET) ] # 计时 start = time.time() for _ in range(100): for formula in formulas: vm.execute(formula, features) elapsed = time.time() - start avg_time = elapsed / (100 * len(formulas)) print(f"\n📊 VM execution speed: {avg_time*1000:.3f}ms per formula") assert avg_time < 0.1 # 应该在 100ms 内完成 if __name__ == "__main__": pytest.main([__file__, "-v"]) ================================================ FILE: backend/tests/test_alpha_mining/test_smoke_p0.py ================================================ """ P0 冒烟测试 - Alpha Mining 核心机制 测试覆盖: 
- F02: 配置模块 - F03-F04: 操作符和时序函数 - F05: 词汇表 - F06-F07: FactorVM 执行和解码 - F08-F09: AlphaGenerator 模型和生成 - F10: AlphaTrainer 训练 - F11: 模拟数据生成 """ import pytest import torch import sys from pathlib import Path # 添加项目路径 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from app.alpha_mining.config import AlphaMiningConfig, DEFAULT_CONFIG from app.alpha_mining.dsl.ops import ( OPS_CONFIG, ts_delay, ts_delta, ts_mean, ts_std, get_op_names ) from app.alpha_mining.dsl.vocab import FactorVocab, FEATURES, DEFAULT_VOCAB from app.alpha_mining.vm.factor_vm import FactorVM from app.alpha_mining.model.alpha_generator import AlphaGenerator from app.alpha_mining.model.trainer import AlphaTrainer from app.alpha_mining.utils import generate_mock_data # ============================================================================ # F02: 配置模块测试 # ============================================================================ class TestConfig: """配置模块测试""" def test_default_config_exists(self): """测试默认配置存在""" assert DEFAULT_CONFIG is not None assert isinstance(DEFAULT_CONFIG, AlphaMiningConfig) def test_config_device(self): """测试设备配置""" config = AlphaMiningConfig() assert config.device in ["cpu", "cuda", "mps"] assert isinstance(config.torch_device, torch.device) def test_config_features(self): """测试特征配置""" config = AlphaMiningConfig() assert len(config.market_features) >= 4 assert len(config.all_features) >= 4 assert config.num_features > 0 # ============================================================================ # F03-F04: 操作符测试 # ============================================================================ class TestOps: """操作符测试""" @pytest.fixture def sample_tensor(self): """创建测试张量""" return torch.randn(10, 100) # [batch=10, time=100] def test_ts_delay(self, sample_tensor): """测试时序延迟""" result = ts_delay(sample_tensor, d=1) assert result.shape == sample_tensor.shape # 第一列应该是 0 assert (result[:, 0] == 0).all() # 后续应该是原始值的延迟 assert torch.allclose(result[:, 1:], sample_tensor[:, :-1]) def test_ts_delta(self, sample_tensor): """测试时序差分""" result = ts_delta(sample_tensor, d=1) assert result.shape == sample_tensor.shape # 差分 = x[t] - x[t-1] expected = sample_tensor - ts_delay(sample_tensor, 1) assert torch.allclose(result, expected) def test_ts_mean(self, sample_tensor): """测试滑动平均""" result = ts_mean(sample_tensor, window=5) assert result.shape == sample_tensor.shape # 值应该在合理范围内 assert not torch.isnan(result).any() def test_ts_std(self, sample_tensor): """测试滑动标准差""" result = ts_std(sample_tensor, window=5) assert result.shape == sample_tensor.shape # 标准差应该非负 assert (result >= 0).all() def test_ops_config_complete(self): """测试操作符配置完整性""" assert len(OPS_CONFIG) >= 10 for name, func, arity in OPS_CONFIG: assert isinstance(name, str) assert callable(func) assert arity in [1, 2, 3] def test_all_ops_executable(self, sample_tensor): """测试所有操作符可执行""" y = torch.randn_like(sample_tensor) z = torch.randn_like(sample_tensor) for name, func, arity in OPS_CONFIG: try: if arity == 1: result = func(sample_tensor) elif arity == 2: result = func(sample_tensor, y) elif arity == 3: result = func(sample_tensor, y, z) assert result.shape == sample_tensor.shape, f"{name} shape mismatch" assert not torch.isnan(result).all(), f"{name} all NaN" except Exception as e: pytest.fail(f"Operator {name} failed: {e}") # ============================================================================ # F05: 词汇表测试 # ============================================================================ class TestVocab: 
"""词汇表测试""" def test_default_vocab_exists(self): """测试默认词汇表存在""" assert DEFAULT_VOCAB is not None assert DEFAULT_VOCAB.vocab_size > 0 def test_vocab_token_mapping(self): """测试 token 映射""" vocab = FactorVocab() # 测试特征映射 assert vocab.token_to_name(0) == FEATURES[0] assert vocab.name_to_token(FEATURES[0]) == 0 # 测试操作符映射 op_names = get_op_names() first_op_token = vocab.num_features assert vocab.token_to_name(first_op_token) == op_names[0] def test_vocab_is_feature_operator(self): """测试特征/操作符判断""" vocab = FactorVocab() # 特征 token assert vocab.is_feature(0) assert not vocab.is_operator(0) # 操作符 token op_token = vocab.num_features assert vocab.is_operator(op_token) assert not vocab.is_feature(op_token) def test_vocab_get_operator_arity(self): """测试获取操作符参数数量""" vocab = FactorVocab() for i, (name, func, arity) in enumerate(OPS_CONFIG): token = vocab.num_features + i assert vocab.get_operator_arity(token) == arity # ============================================================================ # F06-F07: FactorVM 测试 # ============================================================================ class TestFactorVM: """因子执行器测试""" @pytest.fixture def vm(self): """创建 VM 实例""" return FactorVM() @pytest.fixture def features(self): """创建测试特征""" # [batch=10, num_features=6, time=100] return torch.randn(10, 6, 100) def test_vm_execute_simple(self, vm, features): """测试简单表达式执行""" # 只取第一个特征 formula = [0] # RET result = vm.execute(formula, features) assert result is not None assert result.shape == (10, 100) assert torch.allclose(result, features[:, 0, :]) def test_vm_execute_binary_op(self, vm, features): """测试二元操作""" vocab = vm.vocab add_token = vocab.name_to_token("ADD") # ADD(RET, VOL) = features[0] + features[1] formula = [0, 1, add_token] result = vm.execute(formula, features) assert result is not None expected = features[:, 0, :] + features[:, 1, :] assert torch.allclose(result, expected) def test_vm_execute_unary_op(self, vm, features): """测试一元操作""" vocab = vm.vocab neg_token = vocab.name_to_token("NEG") # NEG(RET) = -features[0] formula = [0, neg_token] result = vm.execute(formula, features) assert result is not None expected = -features[:, 0, :] assert torch.allclose(result, expected) def test_vm_execute_invalid_formula(self, vm, features): """测试无效公式""" vocab = vm.vocab add_token = vocab.name_to_token("ADD") # 只有一个参数的 ADD(无效) formula = [0, add_token] result = vm.execute(formula, features) assert result is None # 应该返回 None def test_vm_decode_simple(self, vm): """测试表达式解码""" # RET assert "RET" in vm.decode([0]) # ADD(RET, VOL) vocab = vm.vocab add_token = vocab.name_to_token("ADD") decoded = vm.decode([0, 1, add_token]) assert "ADD" in decoded assert "RET" in decoded def test_vm_validate(self, vm): """测试表达式验证""" vocab = vm.vocab add_token = vocab.name_to_token("ADD") neg_token = vocab.name_to_token("NEG") # 有效公式 assert vm.validate([0]) # RET assert vm.validate([0, neg_token]) # NEG(RET) assert vm.validate([0, 1, add_token]) # ADD(RET, VOL) # 无效公式 assert not vm.validate([add_token]) # ADD without args assert not vm.validate([0, 1]) # Two features, no op # ============================================================================ # F08-F09: AlphaGenerator 测试 # ============================================================================ class TestAlphaGenerator: """因子生成器测试""" @pytest.fixture def generator(self): """创建生成器实例""" config = AlphaMiningConfig(d_model=32, num_layers=1) # 小模型用于测试 return AlphaGenerator(config=config) def test_generator_init(self, generator): """测试生成器初始化""" assert generator.vocab_size > 
0 assert generator.d_model > 0 def test_generator_forward(self, generator): """测试前向传播""" batch_size = 4 seq_len = 5 tokens = torch.zeros((batch_size, seq_len), dtype=torch.long) logits, value = generator(tokens) assert logits.shape == (batch_size, generator.vocab_size) assert value.shape == (batch_size, 1) def test_generator_generate(self, generator): """测试生成功能""" batch_size = 8 max_len = 6 formulas, log_probs = generator.generate( batch_size=batch_size, max_len=max_len ) assert len(formulas) == batch_size assert all(len(f) == max_len for f in formulas) assert len(log_probs) == batch_size def test_generator_generate_with_training(self, generator): """测试训练模式生成""" batch_size = 4 max_len = 6 sequences, log_probs, values = generator.generate_with_training( batch_size=batch_size, max_len=max_len ) assert sequences.shape == (batch_size, max_len) assert len(log_probs) == max_len assert len(values) == max_len # ============================================================================ # F10: AlphaTrainer 测试 # ============================================================================ class TestAlphaTrainer: """训练器测试""" @pytest.fixture def trainer(self): """创建训练器实例""" config = AlphaMiningConfig( d_model=32, num_layers=1, batch_size=16, max_seq_len=6 ) return AlphaTrainer(config=config) @pytest.fixture def mock_data(self): """创建模拟数据""" return generate_mock_data( num_samples=20, num_features=6, time_steps=50, seed=42 ) def test_trainer_init(self, trainer): """测试训练器初始化""" assert trainer.generator is not None assert trainer.vm is not None assert trainer.best_score == -float('inf') def test_trainer_train_step(self, trainer, mock_data): """测试单步训练""" features, returns = mock_data metrics = trainer.train_step(features, returns) assert "loss" in metrics assert "avg_reward" in metrics assert "valid_ratio" in metrics assert trainer.step_count == 1 def test_trainer_short_training(self, trainer, mock_data): """测试短训练(3步)""" features, returns = mock_data result = trainer.train( features, returns, num_steps=3, progress_bar=False ) assert result["total_steps"] == 3 assert "best_score" in result assert len(trainer.training_history) == 3 # ============================================================================ # F11: 模拟数据测试 # ============================================================================ class TestMockData: """模拟数据生成测试""" def test_generate_mock_data_shape(self): """测试模拟数据形状""" features, returns = generate_mock_data( num_samples=50, num_features=6, time_steps=100 ) assert features.shape == (50, 6, 100) assert returns.shape == (50, 100) def test_generate_mock_data_no_nan(self): """测试模拟数据无 NaN""" features, returns = generate_mock_data() assert not torch.isnan(features).any() assert not torch.isnan(returns).any() def test_generate_mock_data_reproducible(self): """测试模拟数据可复现""" f1, r1 = generate_mock_data(seed=42) f2, r2 = generate_mock_data(seed=42) assert torch.allclose(f1, f2) assert torch.allclose(r1, r2) # ============================================================================ # 端到端冒烟测试 # ============================================================================ class TestEndToEnd: """端到端测试""" def test_full_pipeline_smoke(self): """完整流程冒烟测试""" # 1. 创建配置 config = AlphaMiningConfig( d_model=32, num_layers=1, batch_size=8, max_seq_len=6 ) # 2. 创建组件 vocab = FactorVocab() vm = FactorVM(vocab=vocab) generator = AlphaGenerator(vocab=vocab, config=config) trainer = AlphaTrainer(generator=generator, vocab=vocab, config=config) # 3. 
生成模拟数据 features, returns = generate_mock_data( num_samples=10, num_features=6, time_steps=30, seed=42 ) # 4. 生成因子表达式 formulas, _ = generator.generate(batch_size=4, max_len=5) # 5. 执行表达式 valid_count = 0 for formula in formulas: result = vm.execute(formula, features) if result is not None: valid_count += 1 decoded = vm.decode(formula) assert isinstance(decoded, str) # 6. 训练(1步) metrics = trainer.train_step(features, returns) assert metrics["step"] == 1 print(f"\n✅ End-to-end smoke test passed!") print(f" - Valid formulas: {valid_count}/{len(formulas)}") print(f" - Avg reward: {metrics['avg_reward']:.4f}") if __name__ == "__main__": pytest.main([__file__, "-v"]) ================================================ FILE: backend/tests/test_alpha_mining/test_smoke_p1.py ================================================ """ P1 冒烟测试 - Alpha Mining 数据集成 测试覆盖: - F13: MarketFeatureBuilder - F14: SentimentFeatureBuilder - F15: FactorEvaluator - F16: AlphaMiningTool """ import pytest import torch import pandas as pd import numpy as np import sys from pathlib import Path from datetime import datetime, timedelta # 添加项目路径 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from app.alpha_mining.config import AlphaMiningConfig, DEFAULT_CONFIG from app.alpha_mining.features.market import MarketFeatureBuilder from app.alpha_mining.features.sentiment import SentimentFeatureBuilder from app.alpha_mining.backtest.evaluator import FactorEvaluator from app.alpha_mining.utils import generate_mock_data # ============================================================================ # F13: MarketFeatureBuilder 测试 # ============================================================================ class TestMarketFeatureBuilder: """行情特征构建器测试""" @pytest.fixture def builder(self): return MarketFeatureBuilder() @pytest.fixture def sample_df(self): """创建示例 DataFrame""" dates = pd.date_range("2024-01-01", periods=100, freq="D") np.random.seed(42) return pd.DataFrame({ "date": dates, "close": 100 * np.exp(np.cumsum(np.random.randn(100) * 0.02)), "volume": np.abs(np.random.randn(100)) * 1e6 + 1e6, "turnover": np.abs(np.random.randn(100)) * 0.05, }).set_index("date") def test_build_from_dataframe(self, builder, sample_df): """测试从 DataFrame 构建特征""" features = builder.build(sample_df) assert features.dim() == 3 # [batch, features, time] assert features.size(0) == 1 # batch=1 assert features.size(1) == 4 # 4 个特征 assert features.size(2) == 100 # time_steps def test_build_from_tensors(self, builder): """测试从张量字典构建特征""" data = { "close": torch.randn(10, 100).abs() * 100 + 50, "volume": torch.randn(10, 100).abs() * 1e6, } features = builder.build(data) assert features.shape == (10, 4, 100) def test_features_normalized(self, builder, sample_df): """测试特征被正确标准化""" features = builder.build(sample_df) # 检查值在合理范围内 assert features.max() <= 5.0 assert features.min() >= -5.0 def test_no_nan_in_features(self, builder, sample_df): """测试特征无 NaN""" features = builder.build(sample_df) assert not torch.isnan(features).any() assert not torch.isinf(features).any() def test_feature_names(self, builder): """测试特征名称""" names = builder.get_feature_names() assert "RET" in names assert "VOL" in names assert "VOLUME_CHG" in names assert "TURNOVER" in names # ============================================================================ # F14: SentimentFeatureBuilder 测试 # ============================================================================ class TestSentimentFeatureBuilder: """情感特征构建器测试""" @pytest.fixture def builder(self): 
return SentimentFeatureBuilder() @pytest.fixture def sample_df(self): """创建示例 DataFrame""" dates = pd.date_range("2024-01-01", periods=50, freq="D") np.random.seed(42) return pd.DataFrame({ "date": dates, "sentiment": np.random.randn(50) * 0.3, "news_count": np.abs(np.random.randn(50)) * 5 + 1, }).set_index("date") def test_build_from_dataframe(self, builder, sample_df): """测试从 DataFrame 构建特征""" features = builder.build(sample_df) assert features.dim() == 3 assert features.size(0) == 1 assert features.size(1) == 2 # SENTIMENT, NEWS_COUNT assert features.size(2) == 50 def test_build_from_dict(self, builder): """测试从字典构建特征""" data = { "sentiment": [0.1, -0.2, 0.3, 0.0, -0.1], "news_count": [5, 3, 8, 2, 4] } features = builder.build(data) assert features.shape == (1, 2, 5) def test_build_from_list(self, builder): """测试从列表构建特征""" data = [ {"sentiment": 0.1, "news_count": 5}, {"sentiment": -0.2, "news_count": 3}, {"sentiment": 0.3, "news_count": 8}, ] features = builder.build(data) assert features.shape == (1, 2, 3) def test_time_alignment(self, builder): """测试时间步对齐""" data = {"sentiment": [0.1, 0.2, 0.3], "news_count": [1, 2, 3]} features = builder.build(data, time_steps=10) assert features.size(2) == 10 def test_sentiment_decay(self, builder): """测试情感衰减""" # 创建一个有明显峰值的情感序列 data = {"sentiment": [0, 0, 0, 1.0, 0, 0, 0], "news_count": [1] * 7} features = builder.build(data) # 衰减后的值应该逐渐减小 sentiment = features[0, 0, :] assert sentiment[4] < sentiment[3] # 峰值后开始衰减 def test_combine_with_market(self, builder): """测试与行情特征合并""" market = torch.randn(2, 4, 100) # [batch, 4 features, time] sentiment = torch.randn(2, 2, 100) # [batch, 2 features, time] combined = builder.combine_with_market(market, sentiment) assert combined.shape == (2, 6, 100) # ============================================================================ # F15: FactorEvaluator 测试 # ============================================================================ class TestFactorEvaluator: """因子评估器测试""" @pytest.fixture def evaluator(self): return FactorEvaluator() @pytest.fixture def sample_data(self): """创建示例数据""" np.random.seed(42) time_steps = 252 # 模拟收益率 returns = torch.randn(time_steps) * 0.02 # 模拟因子(与收益率有一定相关性) noise = torch.randn(time_steps) * 0.5 factor = returns + noise return factor, returns def test_evaluate_basic(self, evaluator, sample_data): """测试基础评估""" factor, returns = sample_data metrics = evaluator.evaluate(factor, returns) assert "sortino_ratio" in metrics assert "sharpe_ratio" in metrics assert "ic" in metrics assert "rank_ic" in metrics assert "max_drawdown" in metrics assert "turnover" in metrics def test_evaluate_batch(self, evaluator): """测试批量评估""" factor = torch.randn(10, 100) returns = torch.randn(10, 100) * 0.02 metrics = evaluator.evaluate(factor, returns) # 应该返回平均值和标准差 assert "sortino_ratio" in metrics assert "sortino_ratio_std" in metrics def test_get_reward(self, evaluator, sample_data): """测试获取 RL 奖励""" factor, returns = sample_data reward = evaluator.get_reward(factor, returns) assert isinstance(reward, float) assert not np.isnan(reward) def test_good_factor_high_ic(self, evaluator): """测试好因子有较高 IC""" # 创建一个与收益率高度相关的因子 returns = torch.randn(252) * 0.02 factor = returns * 0.8 + torch.randn(252) * 0.01 # 80% 相关 metrics = evaluator.evaluate(factor, returns) # IC 应该显著为正 assert metrics["ic"] > 0.3 def test_random_factor_low_ic(self, evaluator): """测试随机因子 IC 接近 0""" returns = torch.randn(252) * 0.02 factor = torch.randn(252) # 完全随机 metrics = evaluator.evaluate(factor, returns) # IC 应该接近 0 assert abs(metrics["ic"]) < 
0.3 def test_compare_factors(self, evaluator): """测试因子比较""" returns = torch.randn(252) * 0.02 # 创建不同质量的因子 good_factor = returns * 0.8 + torch.randn(252) * 0.01 bad_factor = torch.randn(252) results = evaluator.compare_factors( [good_factor, bad_factor], returns, ["good", "bad"] ) assert "good" in results assert "bad" in results assert results["good"]["ic"] > results["bad"]["ic"] def test_rank_factors(self, evaluator): """测试因子排名""" returns = torch.randn(100) * 0.02 factors = [torch.randn(100) for _ in range(5)] ranking = evaluator.rank_factors(factors, returns) assert len(ranking) == 5 # 检查是降序排列 scores = [score for _, score in ranking] assert scores == sorted(scores, reverse=True) # ============================================================================ # F16: AlphaMiningTool 测试(需要 AgenticX 依赖) # ============================================================================ class TestAlphaMiningToolImport: """AlphaMiningTool 导入测试""" def test_import_tool(self): """测试工具可导入""" try: from app.alpha_mining.tools.alpha_mining_tool import AlphaMiningTool assert AlphaMiningTool is not None except ImportError as e: # 如果 AgenticX 不可用,跳过 pytest.skip(f"AgenticX not available: {e}") def test_tool_metadata(self): """测试工具元数据""" try: from app.alpha_mining.tools.alpha_mining_tool import AlphaMiningTool tool = AlphaMiningTool() assert tool.name == "alpha_mining" assert "量化因子" in tool.description assert len(tool.parameters) > 0 except ImportError: pytest.skip("AgenticX not available") # ============================================================================ # 端到端 P1 测试 # ============================================================================ class TestP1EndToEnd: """P1 端到端测试""" def test_full_pipeline_with_real_features(self): """使用真实特征的完整流程""" # 1. 准备行情数据 dates = pd.date_range("2024-01-01", periods=252, freq="D") np.random.seed(42) market_df = pd.DataFrame({ "close": 100 * np.exp(np.cumsum(np.random.randn(252) * 0.02)), "volume": np.abs(np.random.randn(252)) * 1e6 + 1e6, "turnover": np.abs(np.random.randn(252)) * 0.05, }, index=dates) # 2. 构建行情特征 market_builder = MarketFeatureBuilder() market_features = market_builder.build(market_df) assert market_features.shape == (1, 4, 252) # 3. 准备情感数据 sentiment_data = { "sentiment": np.random.randn(252) * 0.3, "news_count": np.abs(np.random.randn(252)) * 5 + 1 } # 4. 构建情感特征 sentiment_builder = SentimentFeatureBuilder() sentiment_features = sentiment_builder.build(sentiment_data, time_steps=252) assert sentiment_features.shape == (1, 2, 252) # 5. 合并特征 combined = sentiment_builder.combine_with_market( market_features, sentiment_features ) assert combined.shape == (1, 6, 252) # 6. 导入生成器和 VM from app.alpha_mining.model.alpha_generator import AlphaGenerator from app.alpha_mining.vm.factor_vm import FactorVM config = AlphaMiningConfig(d_model=32, num_layers=1, max_seq_len=6) generator = AlphaGenerator(config=config) vm = FactorVM() # 7. 生成并执行因子 formulas, _ = generator.generate(batch_size=5, max_len=5) valid_factors = [] for formula in formulas: factor = vm.execute(formula, combined) if factor is not None and factor.std() > 1e-6: valid_factors.append(factor) # 8. 
评估因子 if valid_factors: evaluator = FactorEvaluator() returns = market_features[:, 0, :] # RET 作为收益率 for factor in valid_factors: metrics = evaluator.evaluate(factor, returns) assert "sortino_ratio" in metrics print(f"\n✅ P1 End-to-end test passed!") print(f" - Market features: {market_features.shape}") print(f" - Sentiment features: {sentiment_features.shape}") print(f" - Combined features: {combined.shape}") print(f" - Valid factors generated: {len(valid_factors)}/{len(formulas)}") if __name__ == "__main__": pytest.main([__file__, "-v"]) ================================================ FILE: backend/tests/test_smoke_alpha_mining.py ================================================ """ Alpha Mining 模块冒烟测试 测试覆盖: 1. DSL 操作符执行 2. 因子虚拟机(FactorVM) 3. 因子生成模型(AlphaGenerator) 4. RL 训练器(AlphaTrainer) 5. 因子评估器(FactorEvaluator) 6. REST API 端点 """ import pytest import torch import numpy as np from typing import List # 确保可以导入模块 import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent / "app")) class TestDSLOperators: """测试 DSL 操作符""" def test_ops_config_exists(self): """操作符配置存在""" from app.alpha_mining.dsl.ops import OPS_CONFIG, get_op_names assert len(OPS_CONFIG) == 21, f"Expected 21 operators, got {len(OPS_CONFIG)}" names = get_op_names() assert 'ADD' in names assert 'SUB' in names assert 'MUL' in names assert 'DIV' in names assert 'MA5' in names assert 'DELAY1' in names def test_arithmetic_ops(self): """算术操作符测试""" from app.alpha_mining.dsl.ops import get_op_by_name x = torch.tensor([1.0, 2.0, 3.0]) y = torch.tensor([2.0, 3.0, 4.0]) # ADD add_fn, add_arity = get_op_by_name('ADD') assert add_arity == 2 result = add_fn(x, y) assert torch.allclose(result, torch.tensor([3.0, 5.0, 7.0])) # MUL mul_fn, mul_arity = get_op_by_name('MUL') result = mul_fn(x, y) assert torch.allclose(result, torch.tensor([2.0, 6.0, 12.0])) # DIV (safe division) div_fn, _ = get_op_by_name('DIV') result = div_fn(x, y) assert result.shape == x.shape assert not torch.any(torch.isinf(result)) def test_timeseries_ops(self): """时序操作符测试""" from app.alpha_mining.dsl.ops import ts_delay, ts_mean, ts_std x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) # Delay delayed = ts_delay(x, 1) assert delayed[0, 0] == 0 # 填充 0 assert delayed[0, 1] == 1 # 原来的第一个值 # MA ma = ts_mean(x, 3) assert ma.shape == x.shape # STD std = ts_std(x, 3) assert std.shape == x.shape class TestFactorVM: """测试因子虚拟机""" @pytest.fixture def vm(self): from app.alpha_mining.vm.factor_vm import FactorVM from app.alpha_mining.dsl.vocab import DEFAULT_VOCAB return FactorVM(vocab=DEFAULT_VOCAB) @pytest.fixture def sample_features(self): """[batch=2, features=4, time=10]""" return torch.randn(2, 4, 10) def test_execute_simple_formula(self, vm, sample_features): """执行简单因子表达式""" # RET + VOL (假设 RET=0, VOL=1, ADD=某个 token) formula = [0, 1, vm.vocab.name_to_token('ADD')] result = vm.execute(formula, sample_features) assert result is not None assert result.shape == (2, 10) # [batch, time] def test_execute_invalid_formula(self, vm, sample_features): """无效表达式返回 None""" # 不完整的表达式 formula = [0] # 只有一个特征,没有操作 result = vm.execute(formula, sample_features) # 只有一个操作数,应该返回该操作数(有效) assert result is not None # 操作符参数不足 formula = [vm.vocab.name_to_token('ADD')] # 二元操作符但没有操作数 result = vm.execute(formula, sample_features) assert result is None def test_decode_formula(self, vm): """解码因子表达式为字符串""" formula = [0, 1, vm.vocab.name_to_token('ADD')] decoded = vm.decode(formula) assert decoded is not None assert 'ADD' in decoded or '+' in decoded class TestAlphaGenerator: 
"""测试因子生成模型""" @pytest.fixture def generator(self): from app.alpha_mining.model.alpha_generator import AlphaGenerator from app.alpha_mining.dsl.vocab import DEFAULT_VOCAB from app.alpha_mining.config import AlphaMiningConfig config = AlphaMiningConfig() return AlphaGenerator(vocab=DEFAULT_VOCAB, config=config) def test_generate_batch(self, generator): """生成一批因子表达式""" formulas, log_probs = generator.generate(batch_size=5, max_len=8) assert len(formulas) == 5 for formula in formulas: assert len(formula) <= 8 assert all(isinstance(t, int) for t in formula) def test_generate_with_training(self, generator): """训练模式生成""" sequences, log_probs_list, values = generator.generate_with_training( batch_size=3, device='cpu' ) assert sequences.shape[0] == 3 assert len(log_probs_list) > 0 class TestAlphaTrainer: """测试 RL 训练器""" @pytest.fixture def trainer(self): from app.alpha_mining.model.trainer import AlphaTrainer from app.alpha_mining.config import AlphaMiningConfig config = AlphaMiningConfig() config.batch_size = 8 return AlphaTrainer(config=config) @pytest.fixture def sample_data(self): """生成样本数据""" features = torch.randn(10, 4, 50) # [samples, features, time] returns = torch.randn(10, 50) # [samples, time] return features, returns def test_train_step(self, trainer, sample_data): """单步训练测试""" features, returns = sample_data metrics = trainer.train_step(features, returns) assert 'step' in metrics assert 'loss' in metrics assert 'avg_reward' in metrics assert 'valid_ratio' in metrics assert 'best_score' in metrics def test_train_with_callback(self, trainer, sample_data): """带回调的训练测试""" features, returns = sample_data callback_results = [] def callback(metrics): callback_results.append(metrics) result = trainer.train( features=features, returns=returns, num_steps=3, progress_bar=False, step_callback=callback ) assert len(callback_results) == 3 assert 'best_score' in result assert 'best_formula_str' in result class TestFactorEvaluator: """测试因子评估器""" @pytest.fixture def evaluator(self): from app.alpha_mining.backtest.evaluator import FactorEvaluator return FactorEvaluator() def test_evaluate_factor(self, evaluator): """评估因子""" factor = torch.randn(50) # 因子值 returns = torch.randn(50) # 收益率 metrics = evaluator.evaluate(factor, returns) assert 'sortino_ratio' in metrics assert 'sharpe_ratio' in metrics assert 'ic' in metrics assert 'rank_ic' in metrics assert 'max_drawdown' in metrics assert 'turnover' in metrics assert 'win_rate' in metrics def test_get_reward(self, evaluator): """获取 RL 奖励""" factor = torch.randn(50) returns = torch.randn(50) reward = evaluator.get_reward(factor, returns) assert isinstance(reward, float) class TestVocab: """测试词汇表""" def test_vocab_initialization(self): """词汇表初始化""" from app.alpha_mining.dsl.vocab import FactorVocab, FEATURES vocab = FactorVocab() assert vocab.vocab_size > 0 assert vocab.num_features == len(FEATURES) assert vocab.num_ops > 0 def test_token_conversion(self): """Token 转换""" from app.alpha_mining.dsl.vocab import FactorVocab vocab = FactorVocab() # 特征转换 token = vocab.name_to_token('RET') name = vocab.token_to_name(token) assert name == 'RET' # 操作符转换 token = vocab.name_to_token('ADD') name = vocab.token_to_name(token) assert name == 'ADD' class TestAPIEndpoints: """测试 REST API 端点(需要 FastAPI TestClient)""" @pytest.fixture def client(self): """创建测试客户端""" try: from fastapi.testclient import TestClient from app.main import app return TestClient(app) except ImportError: pytest.skip("FastAPI TestClient not available") def test_get_operators(self, client): 
"""获取操作符列表""" response = client.get("/api/v1/alpha-mining/operators") assert response.status_code == 200 data = response.json() assert data.get('success') is True assert 'operators' in data assert 'features' in data assert len(data['operators']) == 21 def test_get_factors_empty(self, client): """获取因子列表(空)""" response = client.get("/api/v1/alpha-mining/factors?top_k=5") assert response.status_code == 200 data = response.json() assert data.get('success') is True assert 'factors' in data def test_evaluate_factor(self, client): """评估因子表达式""" response = client.post( "/api/v1/alpha-mining/evaluate", json={"formula": "ADD(RET, VOL)"} ) assert response.status_code == 200 data = response.json() # 可能成功或失败(取决于公式解析) assert 'success' in data def test_mine_task_start(self, client): """启动挖掘任务""" response = client.post( "/api/v1/alpha-mining/mine", json={"num_steps": 5, "use_sentiment": False, "batch_size": 4} ) assert response.status_code == 200 data = response.json() assert data.get('success') is True assert 'task_id' in data class TestEdgeCases: """边界条件测试""" def test_empty_formula(self): """空表达式""" from app.alpha_mining.vm.factor_vm import FactorVM from app.alpha_mining.dsl.vocab import DEFAULT_VOCAB vm = FactorVM(vocab=DEFAULT_VOCAB) features = torch.randn(2, 4, 10) result = vm.execute([], features) assert result is None def test_constant_factor_penalty(self): """常量因子惩罚""" from app.alpha_mining.model.trainer import AlphaTrainer from app.alpha_mining.config import AlphaMiningConfig config = AlphaMiningConfig() trainer = AlphaTrainer(config=config) # 常量因子的标准差接近 0 constant_factor = torch.ones(50) assert constant_factor.std() < config.constant_threshold def test_nan_handling(self): """NaN 处理""" from app.alpha_mining.vm.factor_vm import FactorVM from app.alpha_mining.dsl.vocab import DEFAULT_VOCAB vm = FactorVM(vocab=DEFAULT_VOCAB) # 创建包含 NaN 的特征 features = torch.randn(2, 4, 10) features[0, 0, 5] = float('nan') # 执行应该处理 NaN formula = [0] # 只取第一个特征 result = vm.execute(formula, features) if result is not None: # NaN 应该被替换为 0 assert not torch.any(torch.isnan(result)) # 运行测试 if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) ================================================ FILE: deploy/Dockerfile.celery ================================================ FROM python:3.11 WORKDIR /app # 复制requirements文件和entrypoint脚本 COPY backend/requirements.txt /app/requirements.txt COPY deploy/celery-entrypoint.sh /usr/local/bin/celery-entrypoint.sh # 安装依赖(构建时安装,用于生产环境) # 注意:开发环境使用 volumes 挂载会覆盖 /app,依赖会在 entrypoint 中重新安装 RUN pip install --no-cache-dir -r requirements.txt && \ chmod +x /usr/local/bin/celery-entrypoint.sh # 设置entrypoint(用于开发环境:检查并安装依赖) ENTRYPOINT ["/usr/local/bin/celery-entrypoint.sh"] # 设置默认命令(可以被docker-compose覆盖) CMD ["celery", "-A", "app.core.celery_app", "worker", "--loglevel=info"] ================================================ FILE: deploy/celery-entrypoint.sh ================================================ #!/bin/bash set -e # 开发环境:检查依赖是否已安装(通过检查关键包) # 注意:由于 volumes 挂载会覆盖 /app,构建时安装的依赖可能不可见 # 这个脚本确保在开发环境中依赖总是可用的 CHECK_PACKAGES=("celery" "fastapi" "sqlalchemy") NEED_INSTALL=false for pkg in "${CHECK_PACKAGES[@]}"; do if ! python -c "import ${pkg}" 2>/dev/null; then NEED_INSTALL=true break fi done if [ "$NEED_INSTALL" = true ]; then echo "📦 [开发环境] 检测到依赖未安装,正在安装..." 
echo " 提示:这是因为 volumes 挂载覆盖了镜像中的依赖" pip install --no-cache-dir -r requirements.txt echo "✅ 依赖安装完成" else echo "✅ 依赖已存在,跳过安装" fi # 执行传入的命令 exec "$@" ================================================ FILE: deploy/docker-compose.dev.yml ================================================ version: '3.8' services: postgres: image: postgres:15-alpine container_name: finnews_postgres environment: POSTGRES_USER: finnews POSTGRES_PASSWORD: finnews_dev_password POSTGRES_DB: finnews_db ports: - "5432:5432" volumes: - postgres_data:/var/lib/postgresql/data healthcheck: test: ["CMD-SHELL", "pg_isready -U finnews -d finnews_db"] interval: 10s timeout: 5s retries: 5 networks: - finnews_network redis: image: redis:7-alpine container_name: finnews_redis ports: - "6379:6379" command: redis-server --appendonly yes volumes: - redis_data:/data healthcheck: test: ["CMD", "redis-cli", "ping"] interval: 10s timeout: 5s retries: 5 networks: - finnews_network milvus-etcd: image: quay.io/coreos/etcd:v3.5.5 container_name: finnews_milvus_etcd environment: - ETCD_AUTO_COMPACTION_MODE=revision - ETCD_AUTO_COMPACTION_RETENTION=1000 - ETCD_QUOTA_BACKEND_BYTES=4294967296 - ETCD_SNAPSHOT_COUNT=50000 volumes: - milvus_etcd_data:/etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd healthcheck: test: ["CMD", "etcdctl", "endpoint", "health"] interval: 30s timeout: 20s retries: 3 networks: - finnews_network milvus-minio: image: minio/minio:RELEASE.2023-03-20T20-16-18Z container_name: finnews_milvus_minio environment: MINIO_ACCESS_KEY: minioadmin MINIO_SECRET_KEY: minioadmin ports: - "9001:9001" - "9000:9000" volumes: - milvus_minio_data:/minio_data command: minio server /minio_data --console-address ":9001" healthcheck: test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] interval: 30s timeout: 20s retries: 3 networks: - finnews_network milvus-standalone: image: milvusdb/milvus:v2.3.3 container_name: finnews_milvus command: ["milvus", "run", "standalone"] security_opt: - seccomp:unconfined environment: ETCD_ENDPOINTS: milvus-etcd:2379 MINIO_ADDRESS: milvus-minio:9000 volumes: - milvus_data:/var/lib/milvus healthcheck: test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] interval: 30s start_period: 90s timeout: 20s retries: 3 ports: - "19530:19530" - "9091:9091" depends_on: - milvus-etcd - milvus-minio networks: - finnews_network celery-worker: build: context: .. dockerfile: deploy/Dockerfile.celery container_name: finnews_celery_worker working_dir: /app command: celery -A app.core.celery_app worker --loglevel=info volumes: - ../backend:/app env_file: - ../backend/.env environment: - POSTGRES_USER=finnews - POSTGRES_PASSWORD=finnews_dev_password - POSTGRES_HOST=postgres - POSTGRES_PORT=5432 - POSTGRES_DB=finnews_db - REDIS_HOST=redis - REDIS_PORT=6379 - REDIS_DB=0 - NEO4J_URI=bolt://neo4j:7687 - NEO4J_USER=neo4j - NEO4J_PASSWORD=finnews_neo4j_password depends_on: postgres: condition: service_healthy redis: condition: service_healthy neo4j: condition: service_healthy networks: - finnews_network dns: - 8.8.8.8 - 8.8.4.4 restart: unless-stopped celery-beat: build: context: .. 
dockerfile: deploy/Dockerfile.celery container_name: finnews_celery_beat working_dir: /app command: celery -A app.core.celery_app beat --loglevel=info volumes: - ../backend:/app env_file: - ../backend/.env environment: - POSTGRES_USER=finnews - POSTGRES_PASSWORD=finnews_dev_password - POSTGRES_HOST=postgres - POSTGRES_PORT=5432 - POSTGRES_DB=finnews_db - REDIS_HOST=redis - REDIS_PORT=6379 - REDIS_DB=0 - NEO4J_URI=bolt://neo4j:7687 - NEO4J_USER=neo4j - NEO4J_PASSWORD=finnews_neo4j_password depends_on: postgres: condition: service_healthy redis: condition: service_healthy neo4j: condition: service_healthy networks: - finnews_network dns: - 8.8.8.8 - 8.8.4.4 restart: unless-stopped # Neo4j - 知识图谱数据库 neo4j: image: neo4j:5.26.0 container_name: finnews_neo4j environment: NEO4J_AUTH: neo4j/finnews_neo4j_password NEO4J_PLUGINS: '["apoc", "graph-data-science"]' NEO4J_dbms_memory_pagecache_size: 1G NEO4J_dbms_memory_heap_initial__size: 1G NEO4J_dbms_memory_heap_max__size: 2G NEO4J_apoc_export_file_enabled: 'true' NEO4J_apoc_import_file_enabled: 'true' NEO4J_apoc_import_file_use__neo4j__config: 'true' ports: - "7474:7474" # HTTP - "7687:7687" # Bolt volumes: - neo4j_data:/data - neo4j_logs:/logs - neo4j_import:/var/lib/neo4j/import - neo4j_plugins:/plugins healthcheck: test: ["CMD", "cypher-shell", "-u", "neo4j", "-p", "finnews_neo4j_password", "RETURN 1"] interval: 30s timeout: 10s retries: 3 start_period: 40s networks: - finnews_network restart: unless-stopped volumes: postgres_data: driver: local redis_data: driver: local milvus_etcd_data: driver: local milvus_minio_data: driver: local milvus_data: driver: local neo4j_data: driver: local neo4j_logs: driver: local neo4j_import: driver: local neo4j_plugins: driver: local networks: finnews_network: driver: bridge ================================================ FILE: docs/BochaAI_Web_Search_API_20251222_121535.md ================================================ # BochaAI_Web_Search_API > 来源: https://bocha-ai.feishu.cn/wiki/RXEOw02rFiwzGSkd9mUcqoeAnNK > 爬取时间: 2025-12-22 12:15:35 > 方式: 浏览器提取 --- 博查用户帮助文档 Web Search API 一、API简介 从全网搜索任何网页信息和网页链接,结果准确、摘要完整,更适合AI使用。 可配置搜索时间范围、是否显示摘要,支持按分页获取更多结果。 二、搜索结果 包括网页、图片、视频,Response格式兼容Bing Search API。 • 网页包括name、url、snippet、summary、siteName、siteIcon、datePublished等信息 • 图片包括 contentUrl、hostPageUrl、width、height等信息 三、API接口 请求方式: POST 请求地址: https://api.bochaai.com/v1/web-search 四、请求参数 | 参数 | 类型 | 必填 | 描述 | | --- | --- | --- | --- | | query | string | 是 | 搜索关键词 | | freshness | string | 否 | 搜索时间范围(noLimit, oneDay, oneWeek, oneMonth) | | count | integer | 否 | 返回结果数量(默认10,最大50) | | offset | integer | 否 | 偏移量 | 五、响应定义 返回结果包含 webPages, images, videos 等模块。 每个网页包含 title, url, snippet, datePublished, siteName 等。 六、Python SDK 示例 ```python import requests import json url = "https://api.bochaai.com/v1/web-search" payload = json.dumps({ "query": "彩讯股份", "freshness": "oneMonth", "count": 10 }) headers = { 'Authorization': 'Bearer YOUR_API_KEY', 'Content-Type': 'application/json' } response = requests.request("POST", url, headers=headers, data=payload) print(response.text) ``` ================================================ FILE: docs/天眼查MCP服务_20260104_171528.md ================================================ # 天眼查MCP服务 > 来源: https://bigmodel.cn/marketplace/detail/1846da9039e4 > 爬取时间: 2026-01-04 17:15:28 > 方式: 浏览器提取 --- 控制台 应用空间 体验中心 开发文档 特惠专区 🔥 API Key 财务 返回广场 天眼查 全方位展示企业信息,实时监控企业风险,深挖股权关系,查询企业法律诉讼、知识产权等情况,助力识别风险。 立即体验 介绍信息 价格 工具 使用指南 介绍信息 什么是天眼查MCP服务? 
天眼查 MCP(Model Context Protocol)服务,作为连接天眼查丰富数据资源与各类应用的桥梁,通过标准化接口,为用户在企业信息查询、企业风险评估、企业专利洞察等方面,提供一站式、便捷且高效的数据调用与分析解决方案。该服务依托天眼查海量数据优势,借助 MCP 协议特性,突破传统数据获取与处理瓶颈,让不同类型用户轻松获取所需企业深度信息,辅助商业决策。 支持类型:该 MCP 支持 SSE 和 Streamable 两种协议。 核心功能 (一)企业信息查询 全量工商数据:支持通过企业标识快速获取注册信息、股权结构、分支机构、变更记录等,数据源自权威平台。 多维度筛选:可按行业、注册资本、经营状态等条件精准定位目标企业。 变更轨迹追溯:记录企业工商信息变更历史,辅助分析经营战略调整。 (二)企业风险评估 全维度风险监控:实时同步法律诉讼、失信记录、行政处罚等风险数据,直连法院、工商等权威源。 风险关联分析:挖掘风险传导路径,如关联企业风险扩散、诉讼影响评估等。 实时预警推送:自定义风险类型,目标企业触发条件时即时通知。 (三)企业专利洞察 专利全要素获取:快速调取专利名称、类型、法律状态、发明人等核心字段,评估技术实力。 专利价值量化:综合引用次数、法律状态等指标量化专利资产,辅助投资与合作决策。 侵权智能预警:比对专利相似度,提前识别侵权风险,支持技术研发避坑。 如何在MCP Server上使用天眼查插件服务? MCP Server已完成天眼查插件服务的云端部署,用户操作简便。目前 MCP 服务已支持在体验中心添加使用。 支持运行 MCP 协议的客户端,如Cherry Studio、vscode等中配置,在个人中心的API Key页面复制您的 API 密钥,并按照文档内容设置服务器命令。 天眼查MCP插件服务的关键特性 海量数据支撑:整合全国乃至全球海量企业数据,构建全面企业信息库,无论是新兴创业公司,还是成熟大型集团,都能在其中找到详尽信息。 实时数据更新:与权威数据源实时对接,企业信息变更、风险事件发生、专利状态更新等,第一时间同步至系统,确保用户获取信息始终处于最新状态。 智能检索分析:支持自然语言检索,用户可用日常语言描述需求;同时,内置智能分析引擎,对查询结果进行关联分析、趋势预测等,挖掘数据深层价值。 安全可靠保障:数据传输全程加密,采用多重防护机制抵御网络攻击,确保数据隐私与服务稳定;服务器具备高并发处理能力,满足大量用户同时查询需求。 多端集成便捷:无缝对接浏览器、办公软件及各类管理系统,用户无需切换复杂系统,在日常办公环境中即可随时调用天眼查 MCP 服务,提升工作效率。 价格 工具名称 工具说明 价格 companyBaseInfo 公司名称或ID、类型、成立日期、经营状态、注册资本、 法人、工商注册号、组织机构代码、纳税人识别号等信息 0.15/次 risk 企业自身/周边/预警风险信息 0.2/次 enterprisePatent 包括专利名称、申请号、申请公布号等字段的详细信息 0.1/次 工具 companyBaseInfo 可以通过公司名称或ID获取企业基本信息,企业基本信息包括公司名称或ID、类型、成立日期、经营状态、注册资本、法人、工商注册号、统一社会信用代码、组织机构代码、纳税人识别号等字段信息 risk 可以通过关键字(公司名称、公司id、注册号或社会统一信用代码)获取企业相关天眼风险列表,包括企业自身/周边/预警风险信息 enterprisePatent 可以通过公司名称或ID获取专利的有关信息,包括专利名称、申请号、申请公布号等字段的详细信息 使用指南 天眼查MCP服务的使用场景示例 投资领域:投资人在筛选投资项目时,借助天眼查 MCP 服务,通过企业信息了解目标企业基本面,利用风险评估功能排查潜在风险,依据专利洞察判断企业创新能力与技术壁垒,综合评估投资价值与风险,辅助投资决策。 企业合作:企业寻求合作伙伴时,查询对方企业信息,明确其实力与信誉;评估合作方风险,避免合作过程中陷入法律纠纷、经营异常等陷阱;分析合作方专利布局,判断技术互补性,保障合作顺利开展。 研发创新:研发人员利用企业专利洞察功能,检索行业内相关专利,了解前沿技术动态,避免重复研发;同时,通过分析竞争对手专利,寻找技术创新突破口,优化自身研发路径。 政府招商:政府部门在招商引资过程中,借助天眼查 MCP 服务筛选拥有核心专利、具备创新实力与良好发展前景的企业;评估企业对本地产业带动价值,精准定位优质招商对象,提升招商质量与效果 。 使用教程 支持GLM文本模型API直接调用MCP 或持运行MCP协议的客户端,如Cherry Studio、Vscode、Cursor 点击获取智谱 BigModel 开放平台的API Key 在BigModel体验中心使用 目前 MCP 服务已支持在体验中心添加使用。 打开模型设置,打开MCP开关,点击添加MCP。 选择MCP,确认后,发送Prompt进行对话。 通过GLM文本模型API直接调用 cURL代码示例: curl --request POST \\ --url https://open.bigmodel.cn/api/paas/v4/chat/completions \\ --header 'Authorization: Bearer Your_Zhipu_API_Key' \\ --header 'Content-Type: application/json' \\ --data '{ "model": "glm-4.5", "do_sample": true, "stream": false, "thinking": { "type": "enabled" }, "temperature": 0.6, "top_p": 0.95, "response_format": { "type": "text" }, "messages": [ { "role": "user", "content": "帮我查询下北京天眼查科技有限公司的基本信息" } ], "tools": [ { "mcp": { "transport_type": "sse", "server_label": "tianyancha", "server_url": "https://open.bigmodel.cn/api/mcp-broker/proxy/tianyancha/sse", "headers": { "Authorization": "Bearer Your_Zhipu_API_Key" } }, "type": "mcp" } ] }' 在Cherry studio中使用 1. 在对话界面,点击MCP按钮 2. MCP服务器界面,点击添加服务器 3. 完成以下配置: 3.1 可流式传输的HTTP(streamableHttp) · URL:https://open.bigmodel.cn/api/mcp-broker/proxy/tianyancha/mcp · 请求头:Authorization = Your Zhipu API Key 3.2 服务器发送事件(sse) · URL:https://open.bigmodel.cn/api/mcp-broker/proxy/tianyancha/sse · 请求头:Authorization = Your Zhipu API Key 4. 回到对话界面,选择MCP 5. 进行模型对话,即可使用 在Cursor中使用 Cursor0.45.6版本提供了MCP功能,Cursor将作为MCP服务客户端使用的MCP服务,在Cursor中通过简单的配置就可以完成MCP服务的接入。 操作路径:Cursor设置-->【Tools&Integrations】-->【tianyancha】。 配置MCP服务器 { "mcpServers": { "tianyancha": { "url": "https://open.bigmodel.cn/api/mcp-broker/proxy/tianyancha/mcp?Authorization=Your Zhipu API Key" } } } Cursor MCP使用 Cursor MCP必须在Composer的agent模式下使用。 常见问题解答 1. BigModel上哪些类型的模型支持MCP? 
实际上,MCP 是基于 Function Calling 接口来实现功能的,所以,使用 MCP 所选用的模型必须具备支持 Function Calling 的特性。Bigmodel上现在所有的语言模型(包括GLM-4-Plus、GLM-4-Flash等)均支持Function Calling。Z1系列推理模型因为不支持 Function Calling,无法调用MCP。

2. 如何获取API Key进行调用?

- 前往智谱 BigModel 开放平台的 API Key
- 点击"添加新的API Key"
- Hover新添加的 API Key,点击复制ICON按钮。

3. 现在有哪些MCP是支持在Cursor中使用的?

所有Streamable的MCP,目前是可以在Cursor中调用的,其他的MCP仅支持在我们的体验中心、Cherry Studio和Vscode中使用。后续我们所有的MCP将新增Streamable类型,方便大家使用。

4. 我开发了一个MCP,如何申请入驻BigModel应用空间?

你可以填写申请入驻的表单,我们将优先处理您的合作申请。

官方推荐

- 值得买(北京值得买科技股份有限公司):帮助用户查询商品的优惠信息、商品评测、商品概况、价格、购买渠道、性价比推荐等信息,并给出优惠商品的链接地址。
- 贵金属价格查询(杭州安那其科技有限公司):提供全球贵金属的实时行情、历史价格、K线走势及期货合约数据。
- 农产品行情数据(湖南惠农科技有限公司):全国常见农产品的行情价格数据,来自真实产地和市场用户一手行情,数据真实,可追溯。
- 今日油价查询(杭州玖舟数字科技有限公司):提供全国实时油价、历史价格趋势、调价预测及地区对比,助力车主、物流等场景优化加油决策与成本管理。
- 万物识别(北京智谱华章科技股份有限公司):万物识别MCP服务是智谱提供的基于深度学习的图像识别工具,能够快速分析图片中的地点和人物信息,支持整图及局部区域识别。

================================================
FILE: frontend/.gitignore
================================================
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

================================================
FILE: frontend/QUICKSTART.md
================================================
# FinnewsHunter Frontend 快速启动

## 🚀 5分钟启动

### 1. 安装依赖

```bash
npm install
```

### 2. 配置环境变量

```bash
cp .env.example .env
# 默认配置已经指向 localhost:8000,无需修改
```

### 3. 启动开发服务器

```bash
npm run dev
```

访问 http://localhost:3000

### 4. 确保后端运行

```bash
# 在另一个终端
cd ../backend
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```

---

## 📁 项目结构

- `src/pages/` - 页面组件
- `src/components/ui/` - UI 组件库
- `src/lib/` - 工具函数和 API 客户端
- `src/store/` - Zustand 全局状态
- `src/types/` - TypeScript 类型定义

---

## ✨ 功能演示

### 1. 首页仪表盘

- 统计卡片(总新闻数、任务数、成功率)
- 最新新闻预览

### 2. 新闻流

- 爬取新闻(可配置页码范围)
- 新闻卡片展示
- 一键分析(调用 NewsAnalyst 智能体)
- 情感评分展示

### 3. 任务管理

- 实时任务列表
- 任务状态和进度
- 自动刷新(每5秒)

---

## 🛠️ 开发命令

```bash
# 开发
npm run dev

# 构建
npm run build

# 预览构建
npm run preview

# 代码检查
npm run lint

# 格式化
npm run format
```

---

**享受现代化的开发体验!🎉**

================================================
FILE: frontend/README.md
================================================
# FinnewsHunter Frontend (React + TypeScript)

现代化的金融新闻智能分析平台前端,基于 **React 18 + TypeScript + Vite + Tailwind CSS + Shadcn UI**。

## 技术栈

- **Core**: React 18, TypeScript, Vite
- **UI**: Tailwind CSS, Shadcn UI (Radix Primitives)
- **State**: Zustand, TanStack Query (React Query)
- **Routing**: React Router v6
- **Icons**: Lucide React
- **Notifications**: Sonner

## 快速开始

### 安装依赖

```bash
npm install
# 或使用 pnpm/yarn
```

### 开发模式

```bash
npm run dev
# 访问 http://localhost:3000
```

### 构建生产版本

```bash
npm run build
npm run preview
```

## 项目结构

```
src/
├── components/
│   └── ui/                    # Shadcn UI 组件
│       ├── button.tsx
│       ├── card.tsx
│       └── badge.tsx
├── layout/
│   └── MainLayout.tsx         # 主布局(侧边栏+顶部栏)
├── pages/
│   ├── Dashboard.tsx          # 首页仪表盘
│   ├── NewsListPage.tsx       # 新闻流
│   ├── StockAnalysisPage.tsx  # 个股分析(待实现)
│   ├── AgentMonitorPage.tsx   # 智能体监控(待实现)
│   └── TaskManagerPage.tsx    # 任务管理
├── lib/
│   ├── api-client.ts          # API 客户端
│   └── utils.ts               # 工具函数
├── store/
│   ├── useNewsStore.ts        # 新闻状态
│   └── useTaskStore.ts        # 任务状态
├── types/
│   └── api.ts                 # TypeScript 类型定义
├── App.tsx
├── main.tsx
└── index.css
```

## 功能特性

### ✅ 已实现

- Dashboard 仪表盘(统计卡片)
- 新闻列表展示
- 新闻爬取功能
- 智能分析按钮
- 任务管理列表
- 响应式布局
- 实时数据刷新(React Query)

### 🚧 开发中

- 个股深度分析
- K线图展示
- 智能体监控台
- WebSocket 实时推送
- 辩论可视化

## 开发指南

### 添加新组件

```bash
# 从 Shadcn UI 添加组件
npx shadcn-ui@latest add dialog
npx shadcn-ui@latest add tabs
```

### API 调用

```typescript
import { newsApi } from '@/lib/api-client'
import { useQuery } from '@tanstack/react-query'

const { data, isLoading } = useQuery({
  queryKey: ['news', 'list'],
  queryFn: () => newsApi.getNewsList({ limit: 20 }),
})
```

### 状态管理

```typescript
import { useNewsStore } from '@/store/useNewsStore'

const { newsList, setNewsList } = useNewsStore()
```

## 环境变量

创建 `.env.local` 文件:

```
VITE_API_BASE_URL=http://localhost:8000/api/v1
```

## 与后端集成

确保后端服务运行在 `http://localhost:8000`,前端会自动代理 API 请求到后端。

## 下一步

- [ ] 实现 WebSocket 连接(实时新闻推送)
- [ ] 实现个股分析页面(K线图)
- [ ] 实现智能体监控台(Chain of Thought)
- [ ] 实现辩论可视化(Bull vs Bear)

---

**Built with ❤️ using React + AgenticX**

================================================
FILE: frontend/index.html
================================================
FinnewsHunter - 金融新闻智能分析平台
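The frontend README above expects the backend at `http://localhost:8000`, with the API mounted under `/api/v1`; the alpha-mining routes exercised in the test files earlier in this dump live under `/api/v1/alpha-mining`. The following is a minimal smoke-check sketch of that wiring, not part of the repository: it assumes the FastAPI backend from `backend/` is running locally, and the script name and the `requests` dependency are illustrative.

```python
# check_backend.py -- illustrative only, not part of the repository.
# Assumes the FastAPI backend (backend/app/main.py) is running on localhost:8000.
import requests

BASE = "http://localhost:8000/api/v1/alpha-mining"

# List the DSL features/operators exposed by the alpha-mining API
# (the same endpoint asserted in TestAlphaMiningAPI.test_get_operators above).
resp = requests.get(f"{BASE}/operators", timeout=10)
resp.raise_for_status()
data = resp.json()
print("features:", data.get("features"))
print("operators:", len(data.get("operators", [])))

# Evaluate a simple factor formula, mirroring the payload used in the tests.
resp = requests.post(f"{BASE}/evaluate", json={"formula": "ADD(RET, VOL)"}, timeout=30)
resp.raise_for_status()
print("evaluate success:", resp.json().get("success"))
```

If both calls succeed, the dev-server proxy configured by the frontend should reach the same routes without further changes.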
================================================ FILE: frontend/package.json ================================================ { "name": "finnews-hunter-frontend", "private": true, "version": "0.1.0", "type": "module", "scripts": { "dev": "vite", "build": "tsc && vite build", "preview": "vite preview", "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0", "format": "prettier --write \"src/**/*.{ts,tsx,css}\"" }, "dependencies": { "@radix-ui/react-avatar": "^1.1.0", "@radix-ui/react-dialog": "^1.1.1", "@radix-ui/react-dropdown-menu": "^2.1.1", "@radix-ui/react-label": "^2.1.0", "@radix-ui/react-popover": "^1.1.1", "@radix-ui/react-scroll-area": "^1.1.0", "@radix-ui/react-select": "^2.1.1", "@radix-ui/react-separator": "^1.1.0", "@radix-ui/react-slot": "^1.1.0", "@radix-ui/react-tabs": "^1.1.0", "@radix-ui/react-tooltip": "^1.1.2", "@tanstack/react-query": "^5.28.0", "axios": "^1.6.7", "class-variance-authority": "^0.7.0", "clsx": "^2.1.0", "date-fns": "^3.3.1", "framer-motion": "^11.0.8", "lucide-react": "^0.343.0", "react": "^18.2.0", "react-dom": "^18.2.0", "react-markdown": "^9.0.1", "react-router-dom": "^6.22.2", "recharts": "^2.12.0", "klinecharts": "^9.8.10", "remark-gfm": "^4.0.1", "socket.io-client": "^4.7.4", "sonner": "^1.4.3", "tailwind-merge": "^2.2.1", "tailwindcss-animate": "^1.0.7", "zustand": "^4.5.1" }, "devDependencies": { "@types/node": "^20.11.24", "@types/react": "^18.2.61", "@types/react-dom": "^18.2.19", "@typescript-eslint/eslint-plugin": "^7.1.0", "@typescript-eslint/parser": "^7.1.0", "@vitejs/plugin-react-swc": "^3.5.0", "autoprefixer": "^10.4.18", "eslint": "^8.57.0", "eslint-plugin-react-hooks": "^4.6.0", "eslint-plugin-react-refresh": "^0.4.5", "postcss": "^8.4.35", "prettier": "^3.2.5", "tailwindcss": "^3.4.1", "typescript": "^5.3.3", "vite": "^7.2.7" } } ================================================ FILE: frontend/postcss.config.js ================================================ export default { plugins: { tailwindcss: {}, autoprefixer: {}, }, } ================================================ FILE: frontend/src/App.tsx ================================================ import { Routes, Route } from 'react-router-dom' import { Toaster } from 'sonner' import MainLayout from './layout/MainLayout' import Dashboard from './pages/Dashboard' import NewsListPage from './pages/NewsListPage' import StockSearchPage from './pages/StockSearchPage' import StockAnalysisPage from './pages/StockAnalysisPage' import AgentMonitorPage from './pages/AgentMonitorPage' import TaskManagerPage from './pages/TaskManagerPage' import AlphaMiningPage from './pages/AlphaMiningPage' function App() { return ( <> }> } /> } /> } /> } /> } /> } /> } /> ) } export default App ================================================ FILE: frontend/src/components/DebateChatRoom.tsx ================================================ import React, { useState, useRef, useEffect, useCallback } from 'react' import { Send, User, TrendingUp, TrendingDown, Briefcase, Loader2, Bot, History, Trash2, Search, ChevronDown, CheckCircle2, Clock, ListChecks, PlayCircle, XCircle } from 'lucide-react' import { Button } from '@/components/ui/button' import ReactMarkdown from 'react-markdown' import remarkGfm from 'remark-gfm' import { cn } from '@/lib/utils' import MentionInput, { MentionTarget } from './MentionInput' import type { DebateSession } from '@/store/useDebateStore' import { agentApi, SSEDebateEvent } from '@/lib/api-client' import { toast } from 'sonner' import { useGlobalI18n, 
useLanguageStore } from '@/store/useLanguageStore' // 消息角色类型 export type ChatRole = 'user' | 'bull' | 'bear' | 'manager' | 'system' | 'data_collector' | 'search' // 搜索计划类型 export interface SearchTask { id: string source: string query: string description: string icon: string estimated_time: number } export interface SearchPlan { plan_id: string stock_code: string stock_name: string user_query: string tasks: SearchTask[] total_estimated_time: number } // 聊天消息类型 export interface ChatMessage { id: string role: ChatRole content: string timestamp: Date round?: number isStreaming?: boolean searchPlan?: SearchPlan // 关联的搜索计划 searchStatus?: 'pending' | 'executing' | 'completed' | 'cancelled' } // 获取角色配置(支持国际化) const getRoleConfig = (t: any): Record => ({ user: { name: t.debateHistory.roleNames.user, icon: , bgColor: 'bg-blue-500', textColor: 'text-white', borderColor: 'border-blue-500', align: 'right' }, bull: { name: t.debateHistory.roleNames.bull, icon: , bgColor: 'bg-emerald-500', textColor: 'text-white', borderColor: 'border-emerald-300', align: 'left' }, bear: { name: t.debateHistory.roleNames.bear, icon: , bgColor: 'bg-rose-500', textColor: 'text-white', borderColor: 'border-rose-300', align: 'left' }, manager: { name: t.debateHistory.roleNames.manager, icon: , bgColor: 'bg-indigo-500', textColor: 'text-white', borderColor: 'border-indigo-300', align: 'left' }, data_collector: { name: t.debateHistory.roleNames.data_collector, icon: , bgColor: 'bg-purple-500', textColor: 'text-white', borderColor: 'border-purple-300', align: 'left' }, system: { name: 'System', icon: , bgColor: 'bg-gray-400', textColor: 'text-white', borderColor: 'border-gray-200', align: 'left' }, search: { name: 'Search Results', icon: , bgColor: 'bg-cyan-500', textColor: 'text-white', borderColor: 'border-cyan-300', align: 'left' } }) interface DebateChatRoomProps { messages: ChatMessage[] onSendMessage: (content: string, mentions?: MentionTarget[]) => void isDebating: boolean currentRound?: { round: number; maxRounds: number } | null activeAgent?: string | null stockName?: string disabled?: boolean // 历史相关 historySessions?: DebateSession[] onLoadSession?: (sessionId: string) => void onClearHistory?: () => void showHistory?: boolean // 搜索计划相关 onConfirmSearch?: (plan: SearchPlan, msgId: string) => void onCancelSearch?: (msgId: string) => void } // 搜索计划展示组件 const SearchPlanCard: React.FC<{ plan: SearchPlan, status: string, onConfirm: (plan: SearchPlan) => void, onCancel: () => void }> = ({ plan, status, onConfirm, onCancel }) => { const t = useGlobalI18n() const isPending = status === 'pending' const isExecuting = status === 'executing' return (

📋 {t.debateRoom.searchPlanConfirm}

{plan.tasks.map((task, index) => (
{task.icon || '🔍'}

{index + 1}. {task.description}

{t.debateRoom.roundPrefix === '第' ? '关键词' : 'Keyword'}: "{task.query}"

))}
{t.debateRoom.estimatedTime}: {plan.total_estimated_time}{t.debateRoom.seconds}
{isPending && (
)} {isExecuting && (
{t.debateRoom.searchPlanExecuting}
)} {status === 'completed' && (
{t.debateRoom.searchPlanCompleted}
)}
) } // 单条消息组件 const ChatBubble: React.FC<{ message: ChatMessage, onConfirmSearch?: (plan: SearchPlan, msgId: string) => void, onCancelSearch?: (msgId: string) => void }> = ({ message, onConfirmSearch, onCancelSearch }) => { const t = useGlobalI18n() const ROLE_CONFIG = getRoleConfig(t) const config = ROLE_CONFIG[message.role] const isRight = config.align === 'right' return (
{/* 头像 */}
{config.icon}
{/* 消息体 */}
{/* 角色名称和轮次 */}
{config.name} {message.round && ( {t.debateRoom.roundPrefix}{message.round}{t.debateRoom.roundSuffix} )} {message.timestamp.toLocaleTimeString(t.debateRoom.roundPrefix === '第' ? 'zh-CN' : 'en-US', { hour: '2-digit', minute: '2-digit' })}
{/* 消息气泡 */}
{message.content ? (
{message.content} {message.isStreaming && ( )}
) : message.searchPlan ? (
{t.stockDetail.generatingSearchPlan}
) : (
{t.debateRoom.thinking}
)} {/* Search plan card */} {message.searchPlan && ( onConfirmSearch?.(plan, message.id)} onCancel={() => onCancelSearch?.(message.id)} /> )}
) } // System message component const SystemMessage: React.FC<{ message: ChatMessage }> = ({ message }) => (
{message.content}
) // 主组件 const DebateChatRoom: React.FC = ({ messages, onSendMessage, isDebating, currentRound, activeAgent, stockName, disabled = false, historySessions = [], onLoadSession, onClearHistory, showHistory = true, onConfirmSearch, onCancelSearch }) => { const t = useGlobalI18n() const ROLE_CONFIG = getRoleConfig(t) const [inputValue, setInputValue] = useState('') const [showHistoryDropdown, setShowHistoryDropdown] = useState(false) const [pendingMentions, setPendingMentions] = useState([]) const scrollRef = useRef(null) const historyDropdownRef = useRef(null) // 自动滚动到底部 useEffect(() => { if (scrollRef.current) { scrollRef.current.scrollTop = scrollRef.current.scrollHeight } }, [messages]) // 点击外部关闭历史下拉框 useEffect(() => { const handleClickOutside = (e: MouseEvent) => { if (historyDropdownRef.current && !historyDropdownRef.current.contains(e.target as Node)) { setShowHistoryDropdown(false) } } document.addEventListener('mousedown', handleClickOutside) return () => document.removeEventListener('mousedown', handleClickOutside) }, []) const handleSendWithMentions = useCallback((content: string, mentions: MentionTarget[]) => { if (content.trim() && !disabled && !isDebating) { onSendMessage(content.trim(), mentions) setInputValue('') setPendingMentions([]) } }, [disabled, isDebating, onSendMessage]) // 获取当前活跃角色的提示 const getActiveIndicator = () => { if (!activeAgent) return null const agentMap: Record = { 'BullResearcher': 'bull', 'BearResearcher': 'bear', 'InvestmentManager': 'manager', 'DataCollector': 'data_collector' } const role = agentMap[activeAgent] if (!role) return null const config = ROLE_CONFIG[role] return (
{config.name} {t.debateRoom.typing}
) } return (
{/* Header */}

{stockName ? `${stockName} ${t.debateRoom.title}` : t.debateRoom.titlePlaceholder}

{t.debateRoom.subtitle}

{/* Round indicator */} {currentRound && (
{Array.from({ length: currentRound.maxRounds }, (_, i) => (
))}
{t.debateRoom.roundPrefix}{currentRound.round}{t.debateRoom.roundSuffix}
)}
{/* Messages area */}
{messages.length === 0 ? (

{t.debateRoom.clickStartDebate}

{t.debateRoom.canSpeakDuringDebate}

) : ( messages.map((msg) => ( msg.role === 'system' ? ( ) : ( ) )) )} {/* Typing indicator */} {isDebating && activeAgent && (
{getActiveIndicator()}
)}
{/* Input area */}
{/* Hint and history buttons */}
{isDebating ? (

💡 {t.debateRoom.mentionTip}

) : (

💡 {t.stockDetail.history === '历史' ? '输入 @ 可以选择智能体或数据源' : 'Enter @ to select agents or data sources'}

)} {/* History button */} {showHistory && historySessions.length > 0 && (
{/* History dropdown menu */} {showHistoryDropdown && (
{t.debateHistory.history} {t.stockDetail.session} {onClearHistory && ( )}
{historySessions.map((session, index) => ( ))}
)}
)}
) } export default DebateChatRoom ================================================ FILE: frontend/src/components/DebateConfig.tsx ================================================ /** * 辩论模式配置组件 * 支持选择不同的多智能体协作模式 */ import React, { useState, useEffect } from 'react' import { Settings, Zap, Theater, Rocket, Clock, Users, MessageSquare, ChevronDown, ChevronUp, Info } from 'lucide-react' import { useGlobalI18n } from '@/store/useLanguageStore' // 辩论模式类型 export interface DebateMode { id: string name: string description: string icon: string isDefault?: boolean } // 模式规则配置 export interface ModeRules { maxTime: number maxRounds?: number managerCanInterrupt?: boolean requireDataCollection?: boolean } // 可用的辩论模式(使用函数获取,支持国际化) const getDebateModes = (t: any): DebateMode[] => [ { id: 'parallel', name: t.stockDetail.parallelAnalysis, description: t.stockDetail.parallelAnalysisDesc || 'Bull/Bear parallel analysis, Investment Manager summarizes decision', icon: '⚡', isDefault: true }, { id: 'realtime_debate', name: t.stockDetail.realtimeDebate, description: t.stockDetail.realtimeDebateDesc || 'Four agents real-time dialogue, Investment Manager moderates, Bull/Bear alternate', icon: '🎭' }, { id: 'quick_analysis', name: t.stockDetail.quickAnalysis, description: t.stockDetail.quickAnalysisDesc || 'Single analyst quick recommendation, suitable for time-sensitive scenarios', icon: '🚀' } ] // 默认规则配置 const DEFAULT_RULES: Record = { parallel: { maxTime: 300, maxRounds: 1, managerCanInterrupt: false, requireDataCollection: false }, realtime_debate: { maxTime: 600, maxRounds: 5, managerCanInterrupt: true, requireDataCollection: true }, quick_analysis: { maxTime: 60, maxRounds: 1, managerCanInterrupt: false, requireDataCollection: false } } interface DebateConfigProps { selectedMode: string onModeChange: (mode: string) => void rules?: ModeRules onRulesChange?: (rules: ModeRules) => void disabled?: boolean compact?: boolean } export const DebateConfig: React.FC = ({ selectedMode, onModeChange, rules, onRulesChange, disabled = false, compact = false }) => { const t = useGlobalI18n() const DEBATE_MODES = getDebateModes(t) const [showAdvanced, setShowAdvanced] = useState(false) const [localRules, setLocalRules] = useState( rules || DEFAULT_RULES[selectedMode] || DEFAULT_RULES.parallel ) useEffect(() => { // 模式切换时重置规则为默认值 setLocalRules(DEFAULT_RULES[selectedMode] || DEFAULT_RULES.parallel) }, [selectedMode]) const handleRuleChange = (key: keyof ModeRules, value: number | boolean) => { const newRules = { ...localRules, [key]: value } setLocalRules(newRules) onRulesChange?.(newRules) } const getModeIcon = (mode: DebateMode) => { switch (mode.id) { case 'parallel': return case 'realtime_debate': return case 'quick_analysis': return default: return } } const selectedModeData = DEBATE_MODES.find(m => m.id === selectedMode) if (compact) { return (
) } return (
{/* Mode selection */}

{t.stockDetail.analysisModeConfig || t.stockDetail.analysisMode}

{DEBATE_MODES.map((mode) => ( ))}
{/* Mode description */} {selectedModeData && (
{getModeIcon(selectedModeData)}

{selectedModeData.name}

{selectedModeData.description}

{/* Mode feature tags */}
{selectedMode === 'parallel' && ( <> {t.stockDetail.parallelExecution || 'Parallel Execution'} {t.stockDetail.about2to3min || '~2-3 min'} )} {selectedMode === 'realtime_debate' && ( <> {t.stockDetail.realtimeDialogue || 'Real-time Dialogue'} {t.stockDetail.fourAgents || '4 Agents'} {t.stockDetail.about5to10min || '~5-10 min'} )} {selectedMode === 'quick_analysis' && ( <> {t.stockDetail.singleAgent || 'Single Agent'} {t.stockDetail.about1min || '~1 min'} )}
)} {/* Advanced settings */}
{showAdvanced && (
{/* Max time */}
handleRuleChange('maxTime', parseInt(e.target.value) || 300)} disabled={disabled} min={60} max={1800} step={60} className="w-20 text-sm border border-gray-200 rounded px-2 py-1 text-right disabled:bg-gray-100" /> {t.stockDetail.seconds || 's'}
{/* Settings specific to real-time debate mode */} {selectedMode === 'realtime_debate' && ( <>
handleRuleChange('maxRounds', parseInt(e.target.value) || 5)} disabled={disabled} min={1} max={10} className="w-20 text-sm border border-gray-200 rounded px-2 py-1 text-right disabled:bg-gray-100" /> {t.stockDetail.rounds || 'rounds'}
handleRuleChange('managerCanInterrupt', e.target.checked)} disabled={disabled} className="w-4 h-4 text-blue-600 border-gray-300 rounded focus:ring-blue-500 disabled:cursor-not-allowed" />
handleRuleChange('requireDataCollection', e.target.checked)} disabled={disabled} className="w-4 h-4 text-blue-600 border-gray-300 rounded focus:ring-blue-500 disabled:cursor-not-allowed" />
)}
)}
) } // Debate mode selector (simplified version, used on other pages) export const DebateModeSelector: React.FC<{ value: string onChange: (value: string) => void disabled?: boolean }> = ({ value, onChange, disabled }) => { const t = useGlobalI18n() const DEBATE_MODES = getDebateModes(t) return (
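/* Illustrative usage sketch (not part of the original file) for the DebateConfig component
   defined above; the parent-side state names (mode, rules, isDebating) are hypothetical.

   const [mode, setMode] = useState('parallel')
   const [rules, setRules] = useState<ModeRules | undefined>(undefined)

   <DebateConfig
     selectedMode={mode}
     onModeChange={setMode}
     rules={rules}
     onRulesChange={setRules}
     disabled={isDebating}
   />
*/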
{DEBATE_MODES.map((mode) => ( ))}
) } export default DebateConfig ================================================ FILE: frontend/src/components/DebateHistorySidebar.tsx ================================================ import React, { useState, useMemo } from 'react' import { History, Trash2, MessageSquare, Clock, PlayCircle, Swords, Zap, Activity, X, Search, Calendar } from 'lucide-react' import { Button } from '@/components/ui/button' import { cn } from '@/lib/utils' import type { DebateSession } from '@/store/useDebateStore' import { useGlobalI18n } from '@/store/useLanguageStore' interface DebateHistorySidebarProps { sessions: DebateSession[] currentSessionId?: string | null onLoadSession: (session: DebateSession) => void onDeleteSession?: (sessionId: string) => void onClearHistory?: () => void isOpen: boolean onToggle: () => void } // 获取模式图标和样式(支持国际化) const getModeInfo = (mode: string, t: any) => { switch (mode) { case 'parallel': return { icon: , label: t.stockDetail.parallelAnalysis, color: 'text-amber-600', bgColor: 'bg-amber-50' } case 'realtime_debate': return { icon: , label: t.stockDetail.realtimeDebate, color: 'text-purple-600', bgColor: 'bg-purple-50' } case 'quick_analysis': return { icon: , label: t.stockDetail.quickAnalysis || 'Quick Analysis', color: 'text-blue-600', bgColor: 'bg-blue-50' } default: return { icon: , label: t.stockDetail.bullBear || 'Debate', color: 'text-gray-600', bgColor: 'bg-gray-50' } } } // 格式化时间(支持国际化) const formatTime = (date: Date, t: any) => { const now = new Date() const diff = now.getTime() - date.getTime() const minutes = Math.floor(diff / 60000) const hours = Math.floor(diff / 3600000) const days = Math.floor(diff / 86400000) if (minutes < 1) return t.debateHistory.justNow if (minutes < 60) return `${minutes}${t.debateHistory.minutesAgo}` if (hours < 24) return `${hours}${t.debateHistory.hoursAgo}` if (days < 7) return `${days}${t.debateHistory.daysAgo}` return date.toLocaleDateString(t.debateHistory.justNow === '刚刚' ? 'zh-CN' : 'en-US', { month: 'short', day: 'numeric' }) } // 会话预览内容(支持国际化) const getSessionPreview = (session: DebateSession, t: any) => { if (session.messages.length === 0) { return t.debateHistory.noMessages } // 获取最后一条非系统消息 const lastMessage = [...session.messages] .reverse() .find(m => m.role !== 'system') if (lastMessage) { const roleName = t.debateHistory.roleNames[lastMessage.role] || lastMessage.role const content = lastMessage.content.slice(0, 40) return `${roleName}: ${content}${lastMessage.content.length > 40 ? '...' 
: ''}` } return `${session.messages.length} ${t.debateHistory.messages}` } const DebateHistorySidebar: React.FC = ({ sessions, currentSessionId, onLoadSession, onDeleteSession, onClearHistory, isOpen, onToggle }) => { const t = useGlobalI18n() const [searchTerm, setSearchTerm] = useState('') const [hoveredId, setHoveredId] = useState(null) // 过滤会话 const filteredSessions = useMemo(() => { if (!searchTerm) return sessions const term = searchTerm.toLowerCase() return sessions.filter(s => s.stockName?.toLowerCase().includes(term) || s.messages.some(m => m.content.toLowerCase().includes(term)) ) }, [sessions, searchTerm]) // 按日期分组 const groupedSessions = useMemo(() => { const groups: { label: string; sessions: DebateSession[] }[] = [] const today = new Date() today.setHours(0, 0, 0, 0) const yesterday = new Date(today) yesterday.setDate(yesterday.getDate() - 1) const weekAgo = new Date(today) weekAgo.setDate(weekAgo.getDate() - 7) const todaySessions: DebateSession[] = [] const yesterdaySessions: DebateSession[] = [] const thisWeekSessions: DebateSession[] = [] const olderSessions: DebateSession[] = [] filteredSessions.forEach(session => { const sessionDate = new Date(session.updatedAt) sessionDate.setHours(0, 0, 0, 0) if (sessionDate.getTime() === today.getTime()) { todaySessions.push(session) } else if (sessionDate.getTime() === yesterday.getTime()) { yesterdaySessions.push(session) } else if (sessionDate > weekAgo) { thisWeekSessions.push(session) } else { olderSessions.push(session) } }) if (todaySessions.length > 0) groups.push({ label: t.debateHistory.today, sessions: todaySessions }) if (yesterdaySessions.length > 0) groups.push({ label: t.debateHistory.yesterday, sessions: yesterdaySessions }) if (thisWeekSessions.length > 0) groups.push({ label: t.debateHistory.thisWeek, sessions: thisWeekSessions }) if (olderSessions.length > 0) groups.push({ label: t.debateHistory.older, sessions: olderSessions }) return groups }, [filteredSessions, t]) return ( <> {/* 折叠状态的标签按钮 */} {!isOpen && sessions.length > 0 && ( )} {/* 侧边栏面板 */}
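{/* Worked example (illustrative, not from the original file) for the formatTime helper above:

    const updatedAt = new Date(Date.now() - 95 * 60000)   // 95 minutes ago
    formatTime(updatedAt, t)
    // diff ≈ 95 min → minutes = 95, hours = 1, days = 0; the first matching branch is
    // `hours < 24`, so the label becomes `1${t.debateHistory.hoursAgo}`.
    // Sessions older than 7 days fall through to a locale-formatted month/day string. */}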
{/* Header */}

{t.debateHistory.history}

{sessions.length} {t.stockDetail.session}

{/* Search box */}
setSearchTerm(e.target.value)} placeholder={t.debateHistory.searchPlaceholder} className="w-full pl-9 pr-3 py-2 text-sm border border-gray-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-200 focus:border-indigo-300" />
{/* Session list */}
{groupedSessions.length === 0 ? (

{searchTerm ? t.debateHistory.noMatchingRecords : t.debateHistory.noHistoryYet}

{searchTerm ? t.debateHistory.tryOtherKeywords : t.debateHistory.historyAutoSave}

) : (
{groupedSessions.map(group => (
{group.label}
{group.sessions.map(session => { const modeInfo = getModeInfo(session.mode, t) const isActive = session.id === currentSessionId const isHovered = session.id === hoveredId return (
setHoveredId(session.id)} onMouseLeave={() => setHoveredId(null)} onClick={() => onLoadSession(session)} >
{/* Mode icon */}
{modeInfo.icon}
{/* Session info */}
{session.stockName || session.stockCode} {modeInfo.label}

{getSessionPreview(session, t)}

{session.messages.length} · {formatTime(new Date(session.updatedAt), t)}
{/* Action buttons */}
{onDeleteSession && ( )}
{/* Active indicator */} {isActive && (
)}
) })}
))}
)}
{/* Footer actions */} {sessions.length > 0 && onClearHistory && (
)}
{/* Backdrop overlay */} {isOpen && (
)} ) } export default DebateHistorySidebar ================================================ FILE: frontend/src/components/HighlightText.tsx ================================================ import React from 'react' interface HighlightTextProps { text: string highlight: string className?: string } /** * HighlightText 组件 * * 用于在文本中高亮显示指定的关键词 * * @param text - 原始文本 * @param highlight - 需要高亮的关键词 * @param className - 应用到容器的 CSS 类名 * * @example * */ export default function HighlightText({ text, highlight, className = '' }: HighlightTextProps) { // 如果没有高亮词,直接返回原文本 if (!highlight || !highlight.trim()) { return {text} } // 转义特殊正则字符,避免正则表达式错误 const escapeRegExp = (str: string) => { return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') } try { // 使用正则表达式分割文本,保留匹配部分 const escapedHighlight = escapeRegExp(highlight.trim()) const parts = text.split(new RegExp(`(${escapedHighlight})`, 'gi')) return ( {parts.map((part, index) => { // 判断是否为匹配的关键词(不区分大小写) const isMatch = part.toLowerCase() === highlight.toLowerCase() return isMatch ? ( {part} ) : ( {part} ) })} ) } catch (error) { // 如果正则表达式出错,返回原文本 console.error('HighlightText error:', error) return {text} } } ================================================ FILE: frontend/src/components/KLineChart.tsx ================================================ /** * KLineChart 组件 * 使用 klinecharts 库展示专业的 K 线图 * 支持:蜡烛图、成交量、MA均线、MACD等 */ import { useEffect, useRef, useCallback, useState } from 'react' import { init, dispose, registerLocale } from 'klinecharts' import type { Chart } from 'klinecharts' import type { KLineDataPoint } from '@/types/api' import { cn } from '@/lib/utils' import { useLanguageStore } from '@/store/useLanguageStore' // 注册语言包(使用动态语言) const registerKLineLocales = () => { const { lang } = useLanguageStore.getState(); const t = globalI18n[lang]; registerLocale('zh-CN', { time: `${t.stockDetail.timeLabel}:`, open: `${t.stockDetail.openLabel}:`, high: `${t.stockDetail.highLabel}:`, low: `${t.stockDetail.lowLabel}:`, close: `${t.stockDetail.closeLabel}:`, volume: `${t.stockDetail.volumeLabel}:`, turnover: '额:', change: '涨跌:', }) registerLocale('en-US', { time: `${t.stockDetail.timeLabel}: `, open: `${t.stockDetail.openLabel}: `, high: `${t.stockDetail.highLabel}: `, low: `${t.stockDetail.lowLabel}: `, close: `${t.stockDetail.closeLabel}: `, volume: `${t.stockDetail.volumeLabel}: `, turnover: 'Turnover: ', change: 'Change: ', }) } // 初始化注册 registerLocale('zh-CN', { time: '时间:', open: '开:', high: '高:', low: '低:', close: '收:', volume: '量:', turnover: '额:', change: '涨跌:', }) registerLocale('en-US', { time: 'Time: ', open: 'Open: ', high: 'High: ', low: 'Low: ', close: 'Close: ', volume: 'Volume: ', turnover: 'Turnover: ', change: 'Change: ', }) interface KLineChartProps { data: KLineDataPoint[] height?: number className?: string showVolume?: boolean showMA?: boolean showMACD?: boolean theme?: 'light' | 'dark' period?: 'daily' | '1m' | '5m' | '15m' | '30m' | '60m' // 添加周期参数 } export default function KLineChart({ data, height = 500, className, showVolume = true, showMA = true, showMACD = false, theme = 'light', period = 'daily', }: KLineChartProps) { const { lang } = useLanguageStore() const containerRef = useRef(null) const chartRef = useRef(null) const [isInitialized, setIsInitialized] = useState(false) // 转换数据格式 - klinecharts 需要的格式 const formatData = useCallback((rawData: KLineDataPoint[]) => { return rawData.map((item) => ({ timestamp: item.timestamp, open: item.open, high: item.high, low: item.low, close: item.close, volume: item.volume, turnover: 
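/* Illustrative usage sketch (not from the original file) for the HighlightText component
   defined above; the text and keyword are made-up values.

   <HighlightText
     text="公司发布三季度业绩预告"
     highlight="业绩"
     className="text-sm text-gray-700"
   />

   Every case-insensitive occurrence of `highlight` is split out via a capturing-group regex
   (with special characters escaped first) and rendered with a highlighted style, while the
   remaining segments are rendered as plain spans. */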
item.turnover, })) }, []) // 初始化图表 useEffect(() => { if (!containerRef.current) return // 重置初始化状态 setIsInitialized(false) // 销毁旧图表 if (chartRef.current) { dispose(chartRef.current) chartRef.current = null } // 中国 A 股风格样式:红涨绿跌 const styles = { grid: { show: true, horizontal: { show: true, size: 1, color: theme === 'dark' ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.06)', style: 'dashed' as const, }, vertical: { show: true, size: 1, color: theme === 'dark' ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.06)', style: 'dashed' as const, }, }, candle: { type: 'candle_solid' as const, bar: { upColor: '#EF5350', // 红色涨 downColor: '#26A69A', // 绿色跌 noChangeColor: '#888888', upBorderColor: '#EF5350', downBorderColor: '#26A69A', noChangeBorderColor: '#888888', upWickColor: '#EF5350', downWickColor: '#26A69A', noChangeWickColor: '#888888', }, priceMark: { show: true, high: { show: true, color: theme === 'dark' ? '#D9D9D9' : '#333333', textOffset: 5, textSize: 10, textFamily: 'Helvetica Neue', textWeight: 'normal', }, low: { show: true, color: theme === 'dark' ? '#D9D9D9' : '#333333', textOffset: 5, textSize: 10, textFamily: 'Helvetica Neue', textWeight: 'normal', }, last: { show: true, upColor: '#EF5350', downColor: '#26A69A', noChangeColor: '#888888', line: { show: true, style: 'dashed' as const, dashedValue: [4, 4], size: 1, }, text: { show: true, style: 'fill' as const, size: 12, paddingLeft: 4, paddingTop: 4, paddingRight: 4, paddingBottom: 4, borderColor: 'transparent', borderSize: 0, borderRadius: 2, color: '#FFFFFF', family: 'Helvetica Neue', weight: 'normal', }, }, }, tooltip: { showRule: 'always' as const, showType: 'standard' as const, }, }, indicator: { ohlc: { upColor: '#EF5350', downColor: '#26A69A', noChangeColor: '#888888', }, bars: [ { style: 'fill' as const, borderStyle: 'solid' as const, borderSize: 1, borderDashedValue: [2, 2], upColor: 'rgba(239, 83, 80, 0.7)', downColor: 'rgba(38, 166, 154, 0.7)', noChangeColor: '#888888', }, ], lines: [ { style: 'solid' as const, smooth: false, size: 1, dashedValue: [2, 2], color: '#FF9600' }, { style: 'solid' as const, smooth: false, size: 1, dashedValue: [2, 2], color: '#9D65C9' }, { style: 'solid' as const, smooth: false, size: 1, dashedValue: [2, 2], color: '#2196F3' }, { style: 'solid' as const, smooth: false, size: 1, dashedValue: [2, 2], color: '#E91E63' }, { style: 'solid' as const, smooth: false, size: 1, dashedValue: [2, 2], color: '#00BCD4' }, ], }, xAxis: { show: true, size: 'auto' as const, axisLine: { show: true, color: theme === 'dark' ? 'rgba(255,255,255,0.15)' : 'rgba(0,0,0,0.1)', size: 1, }, tickText: { show: true, color: theme === 'dark' ? '#D9D9D9' : '#666666', family: 'Helvetica Neue', weight: 'normal', size: 11, }, tickLine: { show: true, size: 1, length: 3, color: theme === 'dark' ? 'rgba(255,255,255,0.15)' : 'rgba(0,0,0,0.1)', }, }, yAxis: { show: true, size: 'auto' as const, position: 'right' as const, type: 'normal' as const, inside: false, reverse: false, axisLine: { show: true, color: theme === 'dark' ? 'rgba(255,255,255,0.15)' : 'rgba(0,0,0,0.1)', size: 1, }, tickText: { show: true, color: theme === 'dark' ? '#D9D9D9' : '#666666', family: 'Helvetica Neue', weight: 'normal', size: 11, }, tickLine: { show: true, size: 1, length: 3, color: theme === 'dark' ? 'rgba(255,255,255,0.15)' : 'rgba(0,0,0,0.1)', }, }, crosshair: { show: true, horizontal: { show: true, line: { show: true, style: 'dashed' as const, dashedValue: [4, 2], size: 1, color: theme === 'dark' ? 
'rgba(255,255,255,0.3)' : 'rgba(0,0,0,0.2)', }, text: { show: true, style: 'fill' as const, color: '#FFFFFF', size: 12, family: 'Helvetica Neue', weight: 'normal', borderStyle: 'solid' as const, borderDashedValue: [2, 2], borderSize: 1, borderColor: theme === 'dark' ? 'rgba(255,255,255,0.15)' : 'rgba(0,0,0,0.1)', borderRadius: 2, paddingLeft: 4, paddingRight: 4, paddingTop: 2, paddingBottom: 2, backgroundColor: theme === 'dark' ? 'rgba(35,35,35,0.95)' : 'rgba(50,50,50,0.9)', }, }, vertical: { show: true, line: { show: true, style: 'dashed' as const, dashedValue: [4, 2], size: 1, color: theme === 'dark' ? 'rgba(255,255,255,0.3)' : 'rgba(0,0,0,0.2)', }, text: { show: true, style: 'fill' as const, color: '#FFFFFF', size: 12, family: 'Helvetica Neue', weight: 'normal', borderStyle: 'solid' as const, borderDashedValue: [2, 2], borderSize: 1, borderColor: theme === 'dark' ? 'rgba(255,255,255,0.15)' : 'rgba(0,0,0,0.1)', borderRadius: 2, paddingLeft: 4, paddingRight: 4, paddingTop: 2, paddingBottom: 2, backgroundColor: theme === 'dark' ? 'rgba(35,35,35,0.95)' : 'rgba(50,50,50,0.9)', }, }, }, } // 创建图表 const chart = init(containerRef.current, { locale: lang === 'zh' ? 'zh-CN' : 'en-US', styles, }) if (chart) { chartRef.current = chart // 设置自定义时间格式化 chart.setCustomApi({ formatDate: (dateTimeFormat: any, timestamp: number, format: string, type: number) => { const date = new Date(timestamp) // 日线:只显示日期 if (period === 'daily') { const year = date.getFullYear() const month = String(date.getMonth() + 1).padStart(2, '0') const day = String(date.getDate()).padStart(2, '0') return `${month}-${day}` // 简化为月-日 } // 分钟线:显示月-日 时:分 const month = String(date.getMonth() + 1).padStart(2, '0') const day = String(date.getDate()).padStart(2, '0') const hours = String(date.getHours()).padStart(2, '0') const minutes = String(date.getMinutes()).padStart(2, '0') return `${month}-${day} ${hours}:${minutes}` }, }) // 设置右侧留白为最小,让 K 线尽量占满 chart.setOffsetRightDistance(20) // 先添加 MA 均线到主图(蜡烛图上叠加) if (showMA) { chart.createIndicator('MA', false, { id: 'candle_pane' }) } // 添加成交量指标 - 在独立的副图面板 if (showVolume) { chart.createIndicator('VOL') } // 添加 MACD 指标 - 在独立的副图面板 if (showMACD) { chart.createIndicator('MACD') } // 如果有数据,立即应用 if (data && data.length > 0) { try { const formattedData = formatData(data) chart.applyNewData(formattedData) } catch (error) { console.error('Failed to apply initial chart data:', error) } } setIsInitialized(true) } return () => { setIsInitialized(false) if (chartRef.current) { dispose(chartRef.current) chartRef.current = null } } }, [theme, showVolume, showMA, showMACD, period, lang, data, formatData]) // 更新数据 - 当图表初始化完成且有数据时应用 useEffect(() => { if (!chartRef.current || !isInitialized || !data || data.length === 0) return try { const formattedData = formatData(data) chartRef.current.applyNewData(formattedData) } catch (error) { console.error('Failed to apply chart data:', error) } }, [data, isInitialized, formatData]) return (
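/* Illustrative usage sketch (not from the original file) for this KLineChart component;
   the data points below are hypothetical.

   const demoData: KLineDataPoint[] = [
     { timestamp: 1718582400000, open: 10.2, high: 10.6, low: 10.1, close: 10.5, volume: 1200000, turnover: 12480000 },
     { timestamp: 1718668800000, open: 10.5, high: 10.9, low: 10.4, close: 10.8, volume: 1350000, turnover: 14310000 },
   ]

   <KLineChart data={demoData} height={420} period="daily" showVolume showMA theme="light" />
*/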
) } // 简化版迷你 K 线图组件 export function MiniKLineChart({ data, height = 150, className, }: { data: KLineDataPoint[] height?: number className?: string }) { return ( ) } ================================================ FILE: frontend/src/components/MentionInput.tsx ================================================ import React, { useState, useRef, useEffect, useCallback, useMemo } from 'react' import { TrendingUp, TrendingDown, Briefcase, Search, Database, Globe, Chrome, Bot, Hash, X } from 'lucide-react' import { cn } from '@/lib/utils' import { useGlobalI18n } from '@/store/useLanguageStore' // 可提及的目标类型 export type MentionType = 'agent' | 'source' | 'stock' export interface MentionTarget { type: MentionType id: string label: string description?: string icon: React.ReactNode color: string } // 预定义的智能体列表 const AGENTS: MentionTarget[] = [ { type: 'agent', id: 'bull', label: '多方辩手', description: '分析看多因素', icon: , color: 'text-emerald-600 bg-emerald-50' }, { type: 'agent', id: 'bear', label: '空方辩手', description: '分析看空因素', icon: , color: 'text-rose-600 bg-rose-50' }, { type: 'agent', id: 'manager', label: '投资经理', description: '综合决策', icon: , color: 'text-indigo-600 bg-indigo-50' }, { type: 'agent', id: 'data_collector', label: '数据专员', description: '收集市场数据/动态搜索', icon: , color: 'text-cyan-600 bg-cyan-50' }, ] // 预定义的数据源列表 const SOURCES: MentionTarget[] = [ { type: 'source', id: 'akshare', label: 'AkShare', description: '金融数据接口', icon: , color: 'text-blue-600 bg-blue-50' }, { type: 'source', id: 'bochaai', label: 'BochaAI', description: '实时新闻搜索', icon: , color: 'text-orange-600 bg-orange-50' }, { type: 'source', id: 'browser', label: '网页搜索', description: '多引擎网页搜索', icon: , color: 'text-green-600 bg-green-50' }, { type: 'source', id: 'kb', label: '知识库', description: '历史新闻数据', icon: , color: 'text-amber-600 bg-amber-50' }, ] // 所有可提及目标 const ALL_TARGETS = [...AGENTS, ...SOURCES] interface MentionInputProps { value: string onChange: (value: string) => void onSubmit: (value: string, mentions: MentionTarget[]) => void placeholder?: string disabled?: boolean className?: string // 可选:动态股票列表 stockOptions?: Array<{ code: string; name: string }> } const MentionInput: React.FC = ({ value, onChange, onSubmit, placeholder, disabled = false, className, stockOptions = [] }) => { const t = useGlobalI18n() const defaultPlaceholder = placeholder || t.mentionInput.placeholder const [showPopup, setShowPopup] = useState(false) const [popupPosition, setPopupPosition] = useState({ top: 0, left: 0 }) const [selectedIndex, setSelectedIndex] = useState(0) const [mentionQuery, setMentionQuery] = useState('') const [mentionStartPos, setMentionStartPos] = useState(-1) const [activeMentions, setActiveMentions] = useState([]) const inputRef = useRef(null) const popupRef = useRef(null) // 合并股票选项到目标列表 const allTargets = useMemo(() => { const stockTargets: MentionTarget[] = stockOptions.map(s => ({ type: 'stock' as MentionType, id: s.code, label: s.name, description: s.code, icon: , color: 'text-gray-600 bg-gray-50' })) return [...ALL_TARGETS, ...stockTargets] }, [stockOptions]) // 过滤后的目标列表 const filteredTargets = useMemo(() => { if (!mentionQuery) return allTargets const query = mentionQuery.toLowerCase() return allTargets.filter(t => t.label.toLowerCase().includes(query) || t.id.toLowerCase().includes(query) || t.description?.toLowerCase().includes(query) ) }, [allTargets, mentionQuery]) // 分组显示 const groupedTargets = useMemo(() => { const agents = filteredTargets.filter(t => t.type === 'agent') const sources = 
filteredTargets.filter(t => t.type === 'source') const stocks = filteredTargets.filter(t => t.type === 'stock') const groups: { label: string; items: MentionTarget[] }[] = [] if (agents.length > 0) groups.push({ label: t.mentionInput.agents, items: agents }) if (sources.length > 0) groups.push({ label: t.mentionInput.sources, items: sources }) if (stocks.length > 0) groups.push({ label: t.mentionInput.stocks, items: stocks.slice(0, 5) }) return groups }, [filteredTargets, t]) // 扁平化用于键盘导航 const flatTargets = useMemo(() => { return groupedTargets.flatMap(g => g.items) }, [groupedTargets]) // 处理输入变化 const handleChange = useCallback((e: React.ChangeEvent) => { const newValue = e.target.value const cursorPos = e.target.selectionStart || 0 onChange(newValue) // 检测 @ 符号 const textBeforeCursor = newValue.slice(0, cursorPos) const lastAtIndex = textBeforeCursor.lastIndexOf('@') if (lastAtIndex !== -1) { // 检查 @ 后面是否有空格(如果有,说明不是正在输入的提及) const textAfterAt = textBeforeCursor.slice(lastAtIndex + 1) if (!textAfterAt.includes(' ')) { setMentionQuery(textAfterAt) setMentionStartPos(lastAtIndex) setShowPopup(true) setSelectedIndex(0) // 计算弹窗位置 if (inputRef.current) { const rect = inputRef.current.getBoundingClientRect() setPopupPosition({ top: rect.top - 8, // 在输入框上方显示 left: rect.left }) } return } } setShowPopup(false) setMentionQuery('') setMentionStartPos(-1) }, [onChange]) // 选择提及目标 const selectTarget = useCallback((target: MentionTarget) => { if (mentionStartPos === -1) return const beforeMention = value.slice(0, mentionStartPos) const afterMention = value.slice(mentionStartPos + mentionQuery.length + 1) // +1 for @ const newValue = `${beforeMention}@${target.label} ${afterMention}` onChange(newValue) setActiveMentions(prev => [...prev, target]) setShowPopup(false) setMentionQuery('') setMentionStartPos(-1) // 聚焦回输入框 inputRef.current?.focus() }, [value, mentionStartPos, mentionQuery, onChange]) // 键盘事件处理 const handleKeyDown = useCallback((e: React.KeyboardEvent) => { if (showPopup) { switch (e.key) { case 'ArrowDown': e.preventDefault() setSelectedIndex(prev => prev < flatTargets.length - 1 ? prev + 1 : 0 ) break case 'ArrowUp': e.preventDefault() setSelectedIndex(prev => prev > 0 ? 
prev - 1 : flatTargets.length - 1 ) break case 'Enter': e.preventDefault() if (flatTargets[selectedIndex]) { selectTarget(flatTargets[selectedIndex]) } break case 'Escape': e.preventDefault() setShowPopup(false) break case 'Tab': e.preventDefault() if (flatTargets[selectedIndex]) { selectTarget(flatTargets[selectedIndex]) } break } } else if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault() if (value.trim()) { onSubmit(value.trim(), activeMentions) setActiveMentions([]) } } }, [showPopup, flatTargets, selectedIndex, selectTarget, value, onSubmit, activeMentions]) // 点击外部关闭弹窗 useEffect(() => { const handleClickOutside = (e: MouseEvent) => { if ( popupRef.current && !popupRef.current.contains(e.target as Node) && inputRef.current && !inputRef.current.contains(e.target as Node) ) { setShowPopup(false) } } document.addEventListener('mousedown', handleClickOutside) return () => document.removeEventListener('mousedown', handleClickOutside) }, []) // 滚动选中项到可见区域 useEffect(() => { if (showPopup && popupRef.current) { const selectedElement = popupRef.current.querySelector(`[data-index="${selectedIndex}"]`) selectedElement?.scrollIntoView({ block: 'nearest' }) } }, [selectedIndex, showPopup]) // 移除已添加的提及标签 const removeMention = useCallback((targetId: string) => { const target = activeMentions.find(m => m.id === targetId) if (target) { const newValue = value.replace(`@${target.label}`, '').replace(/\s+/g, ' ').trim() onChange(newValue) setActiveMentions(prev => prev.filter(m => m.id !== targetId)) } }, [activeMentions, value, onChange]) return (
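/* Worked example (illustrative, not from the original file) of the @-mention detection in
   handleChange above. Suppose the textarea value is "请 @bu 分析一下" and the cursor sits right
   after "bu":

     textBeforeCursor = "请 @bu"   → lastAtIndex = 2
     textAfterAt      = "bu"       → no space after the @, so the popup opens
     mentionQuery     = "bu"       → filteredTargets keeps the bull agent (id 'bull' contains "bu")

   Selecting that entry rewrites the value to "请 @多方辩手 ..." via selectTarget, records the
   target in activeMentions, and the collected mentions are passed to onSubmit when Enter is pressed. */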
{/* Selected mention tags */} {activeMentions.length > 0 && (
{activeMentions.map(mention => ( {mention.icon} {mention.label} ))}
)} {/* Input field */} {/* @ mention popup */} {showPopup && filteredTargets.length > 0 && (
使用 ↑↓ 选择,Enter 确认,Esc 取消
{groupedTargets.map((group, groupIndex) => (
0 ? 'mt-2' : ''}>
{group.label}
{group.items.map((target, itemIndex) => { const flatIndex = groupedTargets .slice(0, groupIndex) .reduce((acc, g) => acc + g.items.length, 0) + itemIndex return ( ) })}
))}
)} {/* Empty result hint */} {showPopup && filteredTargets.length === 0 && (
未找到匹配的选项
)}
) } export default MentionInput export { AGENTS, SOURCES, ALL_TARGETS } ================================================ FILE: frontend/src/components/ModelSelector.tsx ================================================ import { useState, useEffect, useMemo } from 'react' import { useQuery } from '@tanstack/react-query' import { Button } from '@/components/ui/button' import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator, DropdownMenuTrigger, } from '@/components/ui/dropdown-menu' import { ChevronDown, Check, AlertCircle } from 'lucide-react' import { cn } from '@/lib/utils' import { llmApi } from '@/lib/api-client' import { useGlobalI18n, useLanguageStore } from '@/store/useLanguageStore' // 模型配置 export interface ModelConfig { provider: string model: string } // Provider 和 Model 的国际化映射 const PROVIDER_I18N: Record = { bailian: { labelZh: '百炼(阿里云)', labelEn: 'Bailian (Alibaba Cloud)', }, openai: { labelZh: 'OpenAI', labelEn: 'OpenAI', }, deepseek: { labelZh: 'DeepSeek', labelEn: 'DeepSeek', }, kimi: { labelZh: 'Kimi (Moonshot)', labelEn: 'Kimi (Moonshot)', }, zhipu: { labelZh: '智谱', labelEn: 'Zhipu', }, } const MODEL_DESCRIPTION_I18N: Record = { bailian: { descZh: '百炼 模型', descEn: 'Bailian Model', }, openai: { descZh: 'OpenAI 模型', descEn: 'OpenAI Model', }, deepseek: { descZh: 'DeepSeek 模型', descEn: 'DeepSeek Model', }, kimi: { descZh: 'Kimi 模型', descEn: 'Kimi Model', }, zhipu: { descZh: '智谱 模型', descEn: 'Zhipu Model', }, } const DEFAULT_CONFIG: ModelConfig = { provider: 'bailian', model: 'qwen-plus', } export default function ModelSelector() { const t = useGlobalI18n() const { lang } = useLanguageStore() const [config, setConfig] = useState(DEFAULT_CONFIG) // 从后端 API 动态加载可用厂商和模型 const { data: llmConfig, isLoading } = useQuery({ queryKey: ['llm-config'], queryFn: llmApi.getConfig, staleTime: 5 * 60 * 1000, // 缓存 5 分钟 retry: 1, }) // 国际化处理:将后端返回的 provider 和 model 数据转换为国际化文本 const providers = useMemo(() => { if (!llmConfig?.providers) return [] return llmConfig.providers.map(provider => { const providerI18n = PROVIDER_I18N[provider.value] || { labelZh: provider.label, labelEn: provider.label } const modelDescI18n = MODEL_DESCRIPTION_I18N[provider.value] || { descZh: `${provider.label} 模型`, descEn: `${provider.label} Model` } return { ...provider, label: lang === 'zh' ? providerI18n.labelZh : providerI18n.labelEn, models: provider.models.map(model => ({ ...model, description: lang === 'zh' ? modelDescI18n.descZh : modelDescI18n.descEn, })), } }) }, [llmConfig?.providers, lang]) // 从 localStorage 加载配置 useEffect(() => { const saved = localStorage.getItem('modelConfig') if (saved) { try { setConfig(JSON.parse(saved)) } catch (e) { console.error('Failed to load model config:', e) } } }, []) // 保存配置到 localStorage const saveConfig = (newConfig: ModelConfig) => { setConfig(newConfig) localStorage.setItem('modelConfig', JSON.stringify(newConfig)) // 触发全局事件,通知其他组件 window.dispatchEvent( new CustomEvent('model-config-changed', { detail: newConfig }) ) } const currentProvider = providers.find((p) => p.value === config.provider) const currentModel = currentProvider?.models.find( (m) => m.value === config.model ) // 加载状态 if (isLoading) { return (
) } // No providers available if (providers.length === 0) { return (
) } return (
{t.model.selectTip} {providers.map((provider) => (
{provider.icon} {provider.label} {!provider.has_api_key && ( ⚠️ {t.model.noApiKey} )}
{provider.models.map((model) => { const isActive = config.provider === provider.value && config.model === model.value return ( saveConfig({ provider: provider.value, model: model.value, }) } disabled={!provider.has_api_key} className={cn( "flex items-start gap-3 rounded-lg border border-transparent px-3 py-3 transition-colors", !provider.has_api_key && "opacity-50 cursor-not-allowed", isActive ? "border-primary/30 bg-primary/5" : "hover:bg-slate-50" )} >
{model.label} {isActive && }
{model.description}
) })}
))}
{t.model.current}:{currentProvider?.label} · {currentModel?.label}
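{/* Illustrative sketch (not from the original file): another component consuming the persisted
    selection via the useModelConfig hook exported below; the component name is hypothetical.

    function AnalysisHeader() {
      const { provider, model } = useModelConfig()
      // re-renders whenever ModelSelector dispatches its 'model-config-changed' CustomEvent
      return <span>{provider} · {model}</span>
    }
*/}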
) } // 导出 hook 供其他组件使用 export function useModelConfig() { const [config, setConfig] = useState(DEFAULT_CONFIG) useEffect(() => { // 加载配置 const saved = localStorage.getItem('modelConfig') if (saved) { try { setConfig(JSON.parse(saved)) } catch (e) { console.error('Failed to load model config:', e) } } // 监听配置变化 const handleConfigChange = (e: CustomEvent) => { setConfig(e.detail) } window.addEventListener( 'model-config-changed', handleConfigChange as EventListener ) return () => { window.removeEventListener( 'model-config-changed', handleConfigChange as EventListener ) } }, []) return config } ================================================ FILE: frontend/src/components/NewsDetailDrawer.tsx ================================================ import { useQuery } from '@tanstack/react-query' import { useState, useEffect } from 'react' import { toast } from 'sonner' import ReactMarkdown from 'react-markdown' import remarkGfm from 'remark-gfm' import { Sheet, SheetContent, SheetHeader, SheetTitle, SheetDescription, } from '@/components/ui/sheet' import { Button } from '@/components/ui/button' import { Badge } from '@/components/ui/badge' import { Card, CardContent } from '@/components/ui/card' import { newsApi, analysisApi } from '@/lib/api-client' import { formatRelativeTime } from '@/lib/utils' import { ExternalLink, Share2, Calendar, TrendingUp, CheckCircle2, XCircle, MinusCircle, Sparkles, Copy, Check, FileText, Code, } from 'lucide-react' // 新闻源配置 const NEWS_SOURCES = [ { key: 'all', name: '全部来源', icon: '📰' }, { key: 'sina', name: '新浪财经', icon: '🌐' }, { key: 'tencent', name: '腾讯财经', icon: '🐧' }, { key: 'jwview', name: '金融界', icon: '💰' }, { key: 'eeo', name: '经济观察网', icon: '📊' }, { key: 'caijing', name: '财经网', icon: '📈' }, { key: 'jingji21', name: '21经济网', icon: '📉' }, { key: 'nbd', name: '每日经济新闻', icon: '📰' }, { key: 'yicai', name: '第一财经', icon: '🎯' }, { key: '163', name: '网易财经', icon: '📧' }, { key: 'eastmoney', name: '东方财富', icon: '💎' }, ] interface NewsDetailDrawerProps { newsId: number | null open: boolean onOpenChange: (open: boolean) => void } export default function NewsDetailDrawer({ newsId, open, onOpenChange, }: NewsDetailDrawerProps) { const [analyzing, setAnalyzing] = useState(false) const [copiedId, setCopiedId] = useState(null) const [showRawHtml, setShowRawHtml] = useState(false) // 是否显示原始 HTML // 清理HTML标签并转换为Markdown const cleanMarkdown = (text: string): string => { return text // 替换HTML换行标签为Markdown换行 .replace(//gi, '\n') .replace(/
/gi, '\n') // 移除其他HTML标签 .replace(/<[^>]+>/g, '') // 清理多余空行 .replace(/\n{3,}/g, '\n\n') .trim() } // 复制功能 const handleCopy = async (text: string, analysisId: number) => { try { await navigator.clipboard.writeText(text) setCopiedId(analysisId) toast.success('已复制到剪贴板') setTimeout(() => setCopiedId(null), 2000) } catch (err) { toast.error('复制失败,请手动复制') } } // 获取新闻详情 const { data: news, isLoading } = useQuery({ queryKey: ['news', 'detail', newsId], queryFn: () => newsApi.getNewsDetail(newsId!), enabled: !!newsId && open, }) // 获取分析结果(如果已分析) const { data: analyses, refetch: refetchAnalyses } = useQuery({ queryKey: ['analysis', 'news', newsId], queryFn: () => analysisApi.getNewsAnalyses(newsId!), enabled: !!newsId && open, staleTime: 0, // 立即过期,确保每次打开都获取最新数据 }) // 获取相关新闻(同来源的其他新闻) const { data: relatedNews } = useQuery({ queryKey: ['news', 'related', newsId], queryFn: async () => { if (!news) return [] const allNews = await newsApi.getLatestNews({ source: news.source, limit: 10 }) // 排除当前新闻,返回前5条 return allNews.filter(n => n.id !== newsId).slice(0, 5) }, enabled: !!newsId && open && !!news, }) // 获取原始 HTML(仅在点击"查看原始内容"时加载) const { data: htmlData, isLoading: htmlLoading } = useQuery({ queryKey: ['news', 'html', newsId], queryFn: () => newsApi.getNewsHtml(newsId!), enabled: !!newsId && open && showRawHtml, }) // 当切换到新新闻时,重置分析状态 useEffect(() => { setAnalyzing(false) }, [newsId]) // 处理分享 const handleShare = async () => { if (!news) return const url = `${window.location.origin}/news/${news.id}` try { await navigator.clipboard.writeText(url) toast.success('链接已复制到剪贴板') } catch (err) { toast.error('复制失败,请手动复制') } } // 处理分析 const handleAnalyze = async () => { if (!newsId) return setAnalyzing(true) try { const result = await analysisApi.analyzeNews(newsId) if (result.success) { toast.success('分析完成!') // 刷新分析数据(不重载整个页面) await refetchAnalyses() } else { toast.error(result.error || '分析失败') } } catch (error) { toast.error('分析失败,请稍后重试') } finally { setAnalyzing(false) } } // 获取情感标签 const getSentimentBadge = (score: number | null) => { if (score === null) { return ( 😐 待分析 ) } if (score > 0.1) { return ( 利好 {score.toFixed(2)} ) } if (score < -0.1) { return ( 利空 {score.toFixed(2)} ) } return ( 中性 {score.toFixed(2)} ) } const sourceInfo = NEWS_SOURCES.find(s => s.key === news?.source) return ( {isLoading ? (
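/* A minimal sketch of the cleanMarkdown helper above, assuming the two empty-looking regex
   literals were meant to match <br> and closing </p> tags, as the surrounding comments suggest
   (replace HTML line breaks with newlines, then strip the remaining tags):

   const cleanMarkdown = (text: string): string =>
     text
       .replace(/<br\s*\/?>/gi, '\n')   // assumed: <br> / <br/> → newline
       .replace(/<\/p>/gi, '\n')        // assumed: paragraph endings → newline
       .replace(/<[^>]+>/g, '')         // strip any remaining HTML tags
       .replace(/\n{3,}/g, '\n\n')      // collapse runs of blank lines
       .trim()
*/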

加载中...

) : !news ? (

新闻不存在

) : (
{/* Header area */} {news.title}
{sourceInfo?.icon || '📰'} {sourceInfo?.name || news.source}
{formatRelativeTime(news.publish_time || news.created_at)}
{news.author && ( <> 作者:{news.author} )}
{/* Action button bar */}
{/* Sentiment analysis card - prefer the latest analysis result */} {(() => { // Prefer the score from the latest analysis record, otherwise fall back to the score on the news row const latestScore = analyses && analyses.length > 0 && analyses[0].sentiment_score !== null ? analyses[0].sentiment_score : news.sentiment_score; if (latestScore === null) return null; return (

情感分析

{getSentimentBadge(latestScore)} 评分:{latestScore.toFixed(3)}
{analyses && analyses.length > 0 && (
分析时间:{formatRelativeTime(analyses[0].created_at)}
)}
)} {/* Related stock codes */} {news.stock_codes && news.stock_codes.length > 0 && (

关联股票

{news.stock_codes.map((code) => ( {code} ))}
)} {/* Full article body */}

{showRawHtml ? : } {showRawHtml ? '原始内容' : '正文内容'}

{showRawHtml ? ( // Raw HTML display area
{htmlLoading ? (
加载原始内容中...
) : htmlData?.raw_html ? (