Repository: 666ghj/MiroFish Branch: main Commit: 1536a7933450 Files: 71 Total size: 1.2 MB Directory structure: gitextract_d857zhx1/ ├── .dockerignore ├── .github/ │ └── workflows/ │ └── docker-image.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README-EN.md ├── README.md ├── backend/ │ ├── app/ │ │ ├── __init__.py │ │ ├── api/ │ │ │ ├── __init__.py │ │ │ ├── graph.py │ │ │ ├── report.py │ │ │ └── simulation.py │ │ ├── config.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── project.py │ │ │ └── task.py │ │ ├── services/ │ │ │ ├── __init__.py │ │ │ ├── graph_builder.py │ │ │ ├── oasis_profile_generator.py │ │ │ ├── ontology_generator.py │ │ │ ├── report_agent.py │ │ │ ├── simulation_config_generator.py │ │ │ ├── simulation_ipc.py │ │ │ ├── simulation_manager.py │ │ │ ├── simulation_runner.py │ │ │ ├── text_processor.py │ │ │ ├── zep_entity_reader.py │ │ │ ├── zep_graph_memory_updater.py │ │ │ └── zep_tools.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── file_parser.py │ │ ├── llm_client.py │ │ ├── logger.py │ │ ├── retry.py │ │ └── zep_paging.py │ ├── pyproject.toml │ ├── requirements.txt │ ├── run.py │ └── scripts/ │ ├── action_logger.py │ ├── run_parallel_simulation.py │ ├── run_reddit_simulation.py │ ├── run_twitter_simulation.py │ └── test_profile_format.py ├── docker-compose.yml ├── frontend/ │ ├── .gitignore │ ├── index.html │ ├── package.json │ ├── src/ │ │ ├── App.vue │ │ ├── api/ │ │ │ ├── graph.js │ │ │ ├── index.js │ │ │ ├── report.js │ │ │ └── simulation.js │ │ ├── components/ │ │ │ ├── GraphPanel.vue │ │ │ ├── HistoryDatabase.vue │ │ │ ├── Step1GraphBuild.vue │ │ │ ├── Step2EnvSetup.vue │ │ │ ├── Step3Simulation.vue │ │ │ ├── Step4Report.vue │ │ │ └── Step5Interaction.vue │ │ ├── main.js │ │ ├── router/ │ │ │ └── index.js │ │ ├── store/ │ │ │ └── pendingUpload.js │ │ └── views/ │ │ ├── Home.vue │ │ ├── InteractionView.vue │ │ ├── MainView.vue │ │ ├── Process.vue │ │ ├── ReportView.vue │ │ ├── SimulationRunView.vue │ │ └── SimulationView.vue │ └── 
vite.config.js └── package.json ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ .git .github .gitignore .cursor .DS_Store .env node_modules frontend/node_modules backend/.venv .venv .python-version __pycache__ *.pyc .pytest_cache .mypy_cache .ruff_cache frontend/dist frontend/.vite backend/uploads ================================================ FILE: .github/workflows/docker-image.yml ================================================ name: Build and push Docker image on: push: tags: ["*"] workflow_dispatch: permissions: contents: read packages: write jobs: build-and-push: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Log in to GHCR uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Extract metadata id: meta uses: docker/metadata-action@v5 with: images: ghcr.io/${{ github.repository_owner }}/mirofish tags: | type=ref,event=tag type=sha type=raw,value=latest - name: Build and push uses: docker/build-push-action@v5 with: context: . 
file: ./Dockerfile push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} ================================================ FILE: .gitignore ================================================ # OS .DS_Store Thumbs.db # 环境变量(保护敏感信息) .env .env.local .env.*.local .env.development .env.test .env.production # Python __pycache__/ *.py[cod] *$py.class *.so .Python .venv/ venv/ ENV/ .eggs/ *.egg-info/ dist/ build/ # Node.js node_modules/ npm-debug.log* yarn-debug.log* yarn-error.log* # IDE .vscode/ .idea/ *.swp *.swo # 测试 .pytest_cache/ .coverage htmlcov/ # Cursor .cursor/ .claude/ # 文档与测试程序 mydoc/ mytest/ # 日志文件 backend/logs/ *.log # 上传文件 backend/uploads/ # Docker 数据 data/ ================================================ FILE: Dockerfile ================================================ FROM python:3.11 # 安装 Node.js (满足 >=18)及必要工具 RUN apt-get update \ && apt-get install -y --no-install-recommends nodejs npm \ && rm -rf /var/lib/apt/lists/* # 从 uv 官方镜像复制 uv COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/ WORKDIR /app # 先复制依赖描述文件以利用缓存 COPY package.json package-lock.json ./ COPY frontend/package.json frontend/package-lock.json ./frontend/ COPY backend/pyproject.toml backend/uv.lock ./backend/ # 安装依赖(Node + Python) RUN npm ci \ && npm ci --prefix frontend \ && cd backend && uv sync --frozen # 复制项目源码 COPY . . EXPOSE 3000 5001 # 同时启动前后端(开发模式) CMD ["npm", "run", "dev"] ================================================ FILE: LICENSE ================================================ GNU AFFERO GENERAL PUBLIC LICENSE Version 3, 19 November 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU Affero General Public License is a free, copyleft license for software and other kinds of works, specifically designed to ensure cooperation with the community in the case of network server software. 
The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, our General Public Licenses are intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. Developers that use our General Public Licenses protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License which gives you legal permission to copy, distribute and/or modify the software. A secondary benefit of defending all users' freedom is that improvements made in alternate versions of the program, if they receive widespread use, become available for other developers to incorporate. Many developers of free software are heartened and encouraged by the resulting cooperation. However, in the case of software used on network servers, this result may fail to come about. The GNU General Public License permits making a modified version and letting the public access it on a server without ever releasing its source code to the public. The GNU Affero General Public License is designed specifically to ensure that, in such cases, the modified source code becomes available to the community. It requires the operator of a network server to provide the source code of the modified version running there to the users of that server. Therefore, public use of a modified version, on a publicly accessible server, gives the public access to the source code of the modified version. 
An older license, called the Affero General Public License and published by Affero, was designed to accomplish similar goals. This is a different license, not a version of the Affero GPL, but Affero has released a new version of the Affero GPL which permits relicensing under this license. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. 
An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. 
However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. 
Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. 
b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. 
b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. 
A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. 
But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. 
Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. 
If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. 
Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. 
The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. 
"Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. 
If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Remote Network Interaction; Use with the GNU General Public License. Notwithstanding any other provision of this License, if you modify the Program, your modified version must prominently offer all users interacting with it remotely through a computer network (if your version supports such interaction) an opportunity to receive the Corresponding Source of your version by providing access to the Corresponding Source from a network server at no charge, through some standard or customary means of facilitating copying of software. This Corresponding Source shall include the Corresponding Source for any work covered by version 3 of the GNU General Public License that is incorporated pursuant to the following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the work with which it is combined will remain governed by version 3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU Affero General Public License from time to time. 
Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/. Also add information on how to contact you by electronic and paper mail. If your software can interact with users remotely through a computer network, you should also make sure that it provides a way for users to get its source. For example, if your program is a web application, its interface could display a "Source" link that leads users to an archive of the code. There are many ways you could offer source, and different solutions will be better for different programs; see section 13 for the specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU AGPL, see https://www.gnu.org/licenses/. ================================================ FILE: README-EN.md ================================================
MiroFish Logo 666ghj%2FMiroFish | Trendshift 简洁通用的群体智能引擎,预测万物
A Simple and Universal Swarm Intelligence Engine, Predicting Anything 666ghj%2FMiroFish | Shanda [![GitHub Stars](https://img.shields.io/github/stars/666ghj/MiroFish?style=flat-square&color=DAA520)](https://github.com/666ghj/MiroFish/stargazers) [![GitHub Watchers](https://img.shields.io/github/watchers/666ghj/MiroFish?style=flat-square)](https://github.com/666ghj/MiroFish/watchers) [![GitHub Forks](https://img.shields.io/github/forks/666ghj/MiroFish?style=flat-square)](https://github.com/666ghj/MiroFish/network) [![Docker](https://img.shields.io/badge/Docker-Build-2496ED?style=flat-square&logo=docker&logoColor=white)](https://hub.docker.com/) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/666ghj/MiroFish) [![Discord](https://img.shields.io/badge/Discord-Join-5865F2?style=flat-square&logo=discord&logoColor=white)](https://discord.gg/ePf5aPaHnA) [![X](https://img.shields.io/badge/X-Follow-000000?style=flat-square&logo=x&logoColor=white)](https://x.com/mirofish_ai) [![Instagram](https://img.shields.io/badge/Instagram-Follow-E4405F?style=flat-square&logo=instagram&logoColor=white)](https://www.instagram.com/mirofish_ai/) [English](./README-EN.md) | [中文文档](./README.md)
## ⚡ Overview **MiroFish** is a next-generation AI prediction engine powered by multi-agent technology. By extracting seed information from the real world (such as breaking news, policy drafts, or financial signals), it automatically constructs a high-fidelity parallel digital world. Within this space, thousands of intelligent agents with independent personalities, long-term memory, and behavioral logic freely interact and undergo social evolution. You can inject variables dynamically from a "God's-eye view" to precisely deduce future trajectories — **rehearse the future in a digital sandbox, and win decisions after countless simulations**. > You only need to: Upload seed materials (data analysis reports or interesting novel stories) and describe your prediction requirements in natural language
> MiroFish will return: A detailed prediction report and a deeply interactive high-fidelity digital world ### Our Vision MiroFish is dedicated to creating a swarm intelligence mirror that maps reality. By capturing the collective emergence triggered by individual interactions, we break through the limitations of traditional prediction: - **At the Macro Level**: We are a rehearsal laboratory for decision-makers, allowing policies and public relations to be tested at zero risk - **At the Micro Level**: We are a creative sandbox for individual users — whether deducing novel endings or exploring imaginative scenarios, everything can be fun, playful, and accessible From serious predictions to playful simulations, we let every "what if" see its outcome, making it possible to predict anything. ## 🌐 Live Demo Welcome to visit our online demo environment and experience a prediction simulation on trending public opinion events we've prepared for you: [mirofish-live-demo](https://666ghj.github.io/mirofish-demo/) ## 📸 Screenshots
Screenshot 1 Screenshot 2
Screenshot 3 Screenshot 4
Screenshot 5 Screenshot 6
## 🎬 Demo Videos ### 1. Wuhan University Public Opinion Simulation + MiroFish Project Introduction
MiroFish Demo Video Click the image to watch the complete demo video of a prediction based on the BettaFish-generated "Wuhan University Public Opinion Report"
### 2. Dream of the Red Chamber Lost Ending Simulation
MiroFish Demo Video Click the image to watch MiroFish's deep prediction of the lost ending based on hundreds of thousands of words from the first 80 chapters of "Dream of the Red Chamber"
> **Financial Prediction**, **Political News Prediction** and more examples coming soon... ## 🔄 Workflow 1. **Graph Building**: Seed extraction & Individual/collective memory injection & GraphRAG construction 2. **Environment Setup**: Entity relationship extraction & Persona generation & Agent configuration injection 3. **Simulation**: Dual-platform parallel simulation & Auto-parse prediction requirements & Dynamic temporal memory updates 4. **Report Generation**: ReportAgent with rich toolset for deep interaction with post-simulation environment 5. **Deep Interaction**: Chat with any agent in the simulated world & Interact with ReportAgent ## 🚀 Quick Start ### Option 1: Source Code Deployment (Recommended) #### Prerequisites | Tool | Version | Description | Check Installation | |------|---------|-------------|-------------------| | **Node.js** | 18+ | Frontend runtime, includes npm | `node -v` | | **Python** | ≥3.11, ≤3.12 | Backend runtime | `python --version` | | **uv** | Latest | Python package manager | `uv --version` | #### 1. Configure Environment Variables ```bash # Copy the example configuration file cp .env.example .env # Edit the .env file and fill in the required API keys ``` **Required Environment Variables:** ```env # LLM API Configuration (supports any LLM API with OpenAI SDK format) # Recommended: Alibaba Qwen-plus model via Bailian Platform: https://bailian.console.aliyun.com/ # High consumption, try simulations with fewer than 40 rounds first LLM_API_KEY=your_api_key LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus # Zep Cloud Configuration # Free monthly quota is sufficient for simple usage: https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key ``` #### 2. 
Install Dependencies ```bash # One-click installation of all dependencies (root + frontend + backend) npm run setup:all ``` Or install step by step: ```bash # Install Node dependencies (root + frontend) npm run setup # Install Python dependencies (backend, auto-creates virtual environment) npm run setup:backend ``` #### 3. Start Services ```bash # Start both frontend and backend (run from project root) npm run dev ``` **Service URLs:** - Frontend: `http://localhost:3000` - Backend API: `http://localhost:5001` **Start Individually:** ```bash npm run backend # Start backend only npm run frontend # Start frontend only ``` ### Option 2: Docker Deployment ```bash # 1. Configure environment variables (same as source deployment) cp .env.example .env # 2. Pull image and start docker compose up -d ``` By default, this reads `.env` from the project root directory and maps ports `3000 (frontend) / 5001 (backend)`. > A mirror address for faster image pulls is provided in the comments of `docker-compose.yml`; substitute it if needed. ## 📬 Join the Conversation
QQ Group
  The MiroFish team is recruiting for full-time and internship positions. If you're interested in multi-agent simulation and LLM applications, feel free to send your resume to: **mirofish@shanda.com** ## 📄 Acknowledgments **MiroFish has received strategic support and incubation from Shanda Group!** MiroFish's simulation engine is powered by **[OASIS (Open Agent Social Interaction Simulations)](https://github.com/camel-ai/oasis)**. We sincerely thank the CAMEL-AI team for their open-source contributions! ## 📈 Project Statistics Star History Chart ================================================ FILE: README.md ================================================
MiroFish Logo 666ghj%2FMiroFish | Trendshift 简洁通用的群体智能引擎,预测万物
A Simple and Universal Swarm Intelligence Engine, Predicting Anything 666ghj%2FMiroFish | Shanda [![GitHub Stars](https://img.shields.io/github/stars/666ghj/MiroFish?style=flat-square&color=DAA520)](https://github.com/666ghj/MiroFish/stargazers) [![GitHub Watchers](https://img.shields.io/github/watchers/666ghj/MiroFish?style=flat-square)](https://github.com/666ghj/MiroFish/watchers) [![GitHub Forks](https://img.shields.io/github/forks/666ghj/MiroFish?style=flat-square)](https://github.com/666ghj/MiroFish/network) [![Docker](https://img.shields.io/badge/Docker-Build-2496ED?style=flat-square&logo=docker&logoColor=white)](https://hub.docker.com/) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/666ghj/MiroFish) [![Discord](https://img.shields.io/badge/Discord-Join-5865F2?style=flat-square&logo=discord&logoColor=white)](https://discord.gg/ePf5aPaHnA) [![X](https://img.shields.io/badge/X-Follow-000000?style=flat-square&logo=x&logoColor=white)](https://x.com/mirofish_ai) [![Instagram](https://img.shields.io/badge/Instagram-Follow-E4405F?style=flat-square&logo=instagram&logoColor=white)](https://www.instagram.com/mirofish_ai/) [English](./README-EN.md) | [中文文档](./README.md)
## ⚡ 项目概述 **MiroFish** 是一款基于多智能体技术的新一代 AI 预测引擎。通过提取现实世界的种子信息(如突发新闻、政策草案、金融信号),自动构建出高保真的平行数字世界。在此空间内,成千上万个具备独立人格、长期记忆与行为逻辑的智能体进行自由交互与社会演化。你可透过「上帝视角」动态注入变量,精准推演未来走向——**让未来在数字沙盘中预演,助决策在百战模拟后胜出**。 > 你只需:上传种子材料(数据分析报告或者有趣的小说故事),并用自然语言描述预测需求
> MiroFish 将返回:一份详尽的预测报告,以及一个可深度交互的高保真数字世界 ### 我们的愿景 MiroFish 致力于打造映射现实的群体智能镜像,通过捕捉个体互动引发的群体涌现,突破传统预测的局限: - **于宏观**:我们是决策者的预演实验室,让政策与公关在零风险中试错 - **于微观**:我们是个人用户的创意沙盘,无论是推演小说结局还是探索脑洞,皆可有趣、好玩、触手可及 从严肃预测到趣味仿真,我们让每一个如果都能看见结果,让预测万物成为可能。 ## 🌐 在线体验 欢迎访问在线 Demo 演示环境,体验我们为你准备的一次关于热点舆情事件的推演预测:[mirofish-live-demo](https://666ghj.github.io/mirofish-demo/) ## 📸 系统截图
截图1 截图2
截图3 截图4
截图5 截图6
## 🎬 演示视频 ### 1. 武汉大学舆情推演预测 + MiroFish项目讲解
MiroFish Demo Video 点击图片查看使用微舆BettaFish生成的《武大舆情报告》进行预测的完整演示视频
### 2. 《红楼梦》失传结局推演预测
MiroFish Demo Video 点击图片查看基于《红楼梦》前80回数十万字,MiroFish深度预测失传结局
> **金融方向推演预测**、**时政要闻推演预测**等示例陆续更新中... ## 🔄 工作流程 1. **图谱构建**:现实种子提取 & 个体与群体记忆注入 & GraphRAG构建 2. **环境搭建**:实体关系抽取 & 人设生成 & 环境配置Agent注入仿真参数 3. **开始模拟**:双平台并行模拟 & 自动解析预测需求 & 动态更新时序记忆 4. **报告生成**:ReportAgent拥有丰富的工具集与模拟后环境进行深度交互 5. **深度互动**:与模拟世界中的任意一位进行对话 & 与ReportAgent进行对话 ## 🚀 快速开始 ### 一、源码部署(推荐) #### 前置要求 | 工具 | 版本要求 | 说明 | 安装检查 | |------|---------|------|---------| | **Node.js** | 18+ | 前端运行环境,包含 npm | `node -v` | | **Python** | ≥3.11, ≤3.12 | 后端运行环境 | `python --version` | | **uv** | 最新版 | Python 包管理器 | `uv --version` | #### 1. 配置环境变量 ```bash # 复制示例配置文件 cp .env.example .env # 编辑 .env 文件,填入必要的 API 密钥 ``` **必需的环境变量:** ```env # LLM API配置(支持 OpenAI SDK 格式的任意 LLM API) # 推荐使用阿里百炼平台qwen-plus模型:https://bailian.console.aliyun.com/ # 注意消耗较大,可先进行小于40轮的模拟尝试 LLM_API_KEY=your_api_key LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus # Zep Cloud 配置 # 每月免费额度即可支撑简单使用:https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key ``` #### 2. 安装依赖 ```bash # 一键安装所有依赖(根目录 + 前端 + 后端) npm run setup:all ``` 或者分步安装: ```bash # 安装 Node 依赖(根目录 + 前端) npm run setup # 安装 Python 依赖(后端,自动创建虚拟环境) npm run setup:backend ``` #### 3. 启动服务 ```bash # 同时启动前后端(在项目根目录执行) npm run dev ``` **服务地址:** - 前端:`http://localhost:3000` - 后端 API:`http://localhost:5001` **单独启动:** ```bash npm run backend # 仅启动后端 npm run frontend # 仅启动前端 ``` ### 二、Docker 部署 ```bash # 1. 配置环境变量(同源码部署) cp .env.example .env # 2. 拉取镜像并启动 docker compose up -d ``` 默认会读取根目录下的 `.env`,并映射端口 `3000(前端)/5001(后端)` > 在 `docker-compose.yml` 中已通过注释提供加速镜像地址,可按需替换 ## 📬 更多交流
QQ交流群
  MiroFish团队长期招募全职/实习,如果你对多Agent应用感兴趣,欢迎投递简历至:**mirofish@shanda.com** ## 📄 致谢 **MiroFish 得到了盛大集团的战略支持和孵化!** MiroFish 的仿真引擎由 **[OASIS](https://github.com/camel-ai/oasis)** 驱动,我们衷心感谢 CAMEL-AI 团队的开源贡献! ## 📈 项目统计 Star History Chart ================================================ FILE: backend/app/__init__.py ================================================ """ MiroFish Backend - Flask应用工厂 """ import os import warnings # 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers) # 需要在所有其他导入之前设置 warnings.filterwarnings("ignore", message=".*resource_tracker.*") from flask import Flask, request from flask_cors import CORS from .config import Config from .utils.logger import setup_logger, get_logger def create_app(config_class=Config): """Flask应用工厂函数""" app = Flask(__name__) app.config.from_object(config_class) # 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式) # Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置 if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'): app.json.ensure_ascii = False # 设置日志 logger = setup_logger('mirofish') # 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次) is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' debug_mode = app.config.get('DEBUG', False) should_log_startup = not debug_mode or is_reloader_process if should_log_startup: logger.info("=" * 50) logger.info("MiroFish Backend 启动中...") logger.info("=" * 50) # 启用CORS CORS(app, resources={r"/api/*": {"origins": "*"}}) # 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程) from .services.simulation_runner import SimulationRunner SimulationRunner.register_cleanup() if should_log_startup: logger.info("已注册模拟进程清理函数") # 请求日志中间件 @app.before_request def log_request(): logger = get_logger('mirofish.request') logger.debug(f"请求: {request.method} {request.path}") if request.content_type and 'json' in request.content_type: logger.debug(f"请求体: {request.get_json(silent=True)}") @app.after_request def log_response(response): logger = get_logger('mirofish.request') logger.debug(f"响应: {response.status_code}") return 
response # 注册蓝图 from .api import graph_bp, simulation_bp, report_bp app.register_blueprint(graph_bp, url_prefix='/api/graph') app.register_blueprint(simulation_bp, url_prefix='/api/simulation') app.register_blueprint(report_bp, url_prefix='/api/report') # 健康检查 @app.route('/health') def health(): return {'status': 'ok', 'service': 'MiroFish Backend'} if should_log_startup: logger.info("MiroFish Backend 启动完成") return app ================================================ FILE: backend/app/api/__init__.py ================================================ """ API路由模块 """ from flask import Blueprint graph_bp = Blueprint('graph', __name__) simulation_bp = Blueprint('simulation', __name__) report_bp = Blueprint('report', __name__) from . import graph # noqa: E402, F401 from . import simulation # noqa: E402, F401 from . import report # noqa: E402, F401 ================================================ FILE: backend/app/api/graph.py ================================================ """ 图谱相关API路由 采用项目上下文机制,服务端持久化状态 """ import os import traceback import threading from flask import request, jsonify from . import graph_bp from ..config import Config from ..services.ontology_generator import OntologyGenerator from ..services.graph_builder import GraphBuilderService from ..services.text_processor import TextProcessor from ..utils.file_parser import FileParser from ..utils.logger import get_logger from ..models.task import TaskManager, TaskStatus from ..models.project import ProjectManager, ProjectStatus # 获取日志器 logger = get_logger('mirofish.api') def allowed_file(filename: str) -> bool: """检查文件扩展名是否允许""" if not filename or '.' 
not in filename: return False ext = os.path.splitext(filename)[1].lower().lstrip('.') return ext in Config.ALLOWED_EXTENSIONS # ============== 项目管理接口 ============== @graph_bp.route('/project/', methods=['GET']) def get_project(project_id: str): """ 获取项目详情 """ project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, "error": f"项目不存在: {project_id}" }), 404 return jsonify({ "success": True, "data": project.to_dict() }) @graph_bp.route('/project/list', methods=['GET']) def list_projects(): """ 列出所有项目 """ limit = request.args.get('limit', 50, type=int) projects = ProjectManager.list_projects(limit=limit) return jsonify({ "success": True, "data": [p.to_dict() for p in projects], "count": len(projects) }) @graph_bp.route('/project/', methods=['DELETE']) def delete_project(project_id: str): """ 删除项目 """ success = ProjectManager.delete_project(project_id) if not success: return jsonify({ "success": False, "error": f"项目不存在或删除失败: {project_id}" }), 404 return jsonify({ "success": True, "message": f"项目已删除: {project_id}" }) @graph_bp.route('/project//reset', methods=['POST']) def reset_project(project_id: str): """ 重置项目状态(用于重新构建图谱) """ project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, "error": f"项目不存在: {project_id}" }), 404 # 重置到本体已生成状态 if project.ontology: project.status = ProjectStatus.ONTOLOGY_GENERATED else: project.status = ProjectStatus.CREATED project.graph_id = None project.graph_build_task_id = None project.error = None ProjectManager.save_project(project) return jsonify({ "success": True, "message": f"项目已重置: {project_id}", "data": project.to_dict() }) # ============== 接口1:上传文件并生成本体 ============== @graph_bp.route('/ontology/generate', methods=['POST']) def generate_ontology(): """ 接口1:上传文件,分析生成本体定义 请求方式:multipart/form-data 参数: files: 上传的文件(PDF/MD/TXT),可多个 simulation_requirement: 模拟需求描述(必填) project_name: 项目名称(可选) additional_context: 额外说明(可选) 返回: { "success": true, "data": { 
"project_id": "proj_xxxx", "ontology": { "entity_types": [...], "edge_types": [...], "analysis_summary": "..." }, "files": [...], "total_text_length": 12345 } } """ try: logger.info("=== 开始生成本体定义 ===") # 获取参数 simulation_requirement = request.form.get('simulation_requirement', '') project_name = request.form.get('project_name', 'Unnamed Project') additional_context = request.form.get('additional_context', '') logger.debug(f"项目名称: {project_name}") logger.debug(f"模拟需求: {simulation_requirement[:100]}...") if not simulation_requirement: return jsonify({ "success": False, "error": "请提供模拟需求描述 (simulation_requirement)" }), 400 # 获取上传的文件 uploaded_files = request.files.getlist('files') if not uploaded_files or all(not f.filename for f in uploaded_files): return jsonify({ "success": False, "error": "请至少上传一个文档文件" }), 400 # 创建项目 project = ProjectManager.create_project(name=project_name) project.simulation_requirement = simulation_requirement logger.info(f"创建项目: {project.project_id}") # 保存文件并提取文本 document_texts = [] all_text = "" for file in uploaded_files: if file and file.filename and allowed_file(file.filename): # 保存文件到项目目录 file_info = ProjectManager.save_file_to_project( project.project_id, file, file.filename ) project.files.append({ "filename": file_info["original_filename"], "size": file_info["size"] }) # 提取文本 text = FileParser.extract_text(file_info["path"]) text = TextProcessor.preprocess_text(text) document_texts.append(text) all_text += f"\n\n=== {file_info['original_filename']} ===\n{text}" if not document_texts: ProjectManager.delete_project(project.project_id) return jsonify({ "success": False, "error": "没有成功处理任何文档,请检查文件格式" }), 400 # 保存提取的文本 project.total_text_length = len(all_text) ProjectManager.save_extracted_text(project.project_id, all_text) logger.info(f"文本提取完成,共 {len(all_text)} 字符") # 生成本体 logger.info("调用 LLM 生成本体定义...") generator = OntologyGenerator() ontology = generator.generate( document_texts=document_texts, 
simulation_requirement=simulation_requirement, additional_context=additional_context if additional_context else None ) # 保存本体到项目 entity_count = len(ontology.get("entity_types", [])) edge_count = len(ontology.get("edge_types", [])) logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型") project.ontology = { "entity_types": ontology.get("entity_types", []), "edge_types": ontology.get("edge_types", []) } project.analysis_summary = ontology.get("analysis_summary", "") project.status = ProjectStatus.ONTOLOGY_GENERATED ProjectManager.save_project(project) logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}") return jsonify({ "success": True, "data": { "project_id": project.project_id, "project_name": project.name, "ontology": project.ontology, "analysis_summary": project.analysis_summary, "files": project.files, "total_text_length": project.total_text_length } }) except Exception as e: return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 # ============== 接口2:构建图谱 ============== @graph_bp.route('/build', methods=['POST']) def build_graph(): """ 接口2:根据project_id构建图谱 请求(JSON): { "project_id": "proj_xxxx", // 必填,来自接口1 "graph_name": "图谱名称", // 可选 "chunk_size": 500, // 可选,默认500 "chunk_overlap": 50 // 可选,默认50 } 返回: { "success": true, "data": { "project_id": "proj_xxxx", "task_id": "task_xxxx", "message": "图谱构建任务已启动" } } """ try: logger.info("=== 开始构建图谱 ===") # 检查配置 errors = [] if not Config.ZEP_API_KEY: errors.append("ZEP_API_KEY未配置") if errors: logger.error(f"配置错误: {errors}") return jsonify({ "success": False, "error": "配置错误: " + "; ".join(errors) }), 500 # 解析请求 data = request.get_json() or {} project_id = data.get('project_id') logger.debug(f"请求参数: project_id={project_id}") if not project_id: return jsonify({ "success": False, "error": "请提供 project_id" }), 400 # 获取项目 project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, "error": f"项目不存在: {project_id}" }), 404 # 检查项目状态 force = 
data.get('force', False) # 强制重新构建 if project.status == ProjectStatus.CREATED: return jsonify({ "success": False, "error": "项目尚未生成本体,请先调用 /ontology/generate" }), 400 if project.status == ProjectStatus.GRAPH_BUILDING and not force: return jsonify({ "success": False, "error": "图谱正在构建中,请勿重复提交。如需强制重建,请添加 force: true", "task_id": project.graph_build_task_id }), 400 # 如果强制重建,重置状态 if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]: project.status = ProjectStatus.ONTOLOGY_GENERATED project.graph_id = None project.graph_build_task_id = None project.error = None # 获取配置 graph_name = data.get('graph_name', project.name or 'MiroFish Graph') chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE) chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP) # 更新项目配置 project.chunk_size = chunk_size project.chunk_overlap = chunk_overlap # 获取提取的文本 text = ProjectManager.get_extracted_text(project_id) if not text: return jsonify({ "success": False, "error": "未找到提取的文本内容" }), 400 # 获取本体 ontology = project.ontology if not ontology: return jsonify({ "success": False, "error": "未找到本体定义" }), 400 # 创建异步任务 task_manager = TaskManager() task_id = task_manager.create_task(f"构建图谱: {graph_name}") logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}") # 更新项目状态 project.status = ProjectStatus.GRAPH_BUILDING project.graph_build_task_id = task_id ProjectManager.save_project(project) # 启动后台任务 def build_task(): build_logger = get_logger('mirofish.build') try: build_logger.info(f"[{task_id}] 开始构建图谱...") task_manager.update_task( task_id, status=TaskStatus.PROCESSING, message="初始化图谱构建服务..." 
) # 创建图谱构建服务 builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) # 分块 task_manager.update_task( task_id, message="文本分块中...", progress=5 ) chunks = TextProcessor.split_text( text, chunk_size=chunk_size, overlap=chunk_overlap ) total_chunks = len(chunks) # 创建图谱 task_manager.update_task( task_id, message="创建Zep图谱...", progress=10 ) graph_id = builder.create_graph(name=graph_name) # 更新项目的graph_id project.graph_id = graph_id ProjectManager.save_project(project) # 设置本体 task_manager.update_task( task_id, message="设置本体定义...", progress=15 ) builder.set_ontology(graph_id, ontology) # 添加文本(progress_callback 签名是 (msg, progress_ratio)) def add_progress_callback(msg, progress_ratio): progress = 15 + int(progress_ratio * 40) # 15% - 55% task_manager.update_task( task_id, message=msg, progress=progress ) task_manager.update_task( task_id, message=f"开始添加 {total_chunks} 个文本块...", progress=15 ) episode_uuids = builder.add_text_batches( graph_id, chunks, batch_size=3, progress_callback=add_progress_callback ) # 等待Zep处理完成(查询每个episode的processed状态) task_manager.update_task( task_id, message="等待Zep处理数据...", progress=55 ) def wait_progress_callback(msg, progress_ratio): progress = 55 + int(progress_ratio * 35) # 55% - 90% task_manager.update_task( task_id, message=msg, progress=progress ) builder._wait_for_episodes(episode_uuids, wait_progress_callback) # 获取图谱数据 task_manager.update_task( task_id, message="获取图谱数据...", progress=95 ) graph_data = builder.get_graph_data(graph_id) # 更新项目状态 project.status = ProjectStatus.GRAPH_COMPLETED ProjectManager.save_project(project) node_count = graph_data.get("node_count", 0) edge_count = graph_data.get("edge_count", 0) build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}") # 完成 task_manager.update_task( task_id, status=TaskStatus.COMPLETED, message="图谱构建完成", progress=100, result={ "project_id": project_id, "graph_id": graph_id, "node_count": node_count, "edge_count": edge_count, "chunk_count": total_chunks } 
) except Exception as e: # 更新项目状态为失败 build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}") build_logger.debug(traceback.format_exc()) project.status = ProjectStatus.FAILED project.error = str(e) ProjectManager.save_project(project) task_manager.update_task( task_id, status=TaskStatus.FAILED, message=f"构建失败: {str(e)}", error=traceback.format_exc() ) # 启动后台线程 thread = threading.Thread(target=build_task, daemon=True) thread.start() return jsonify({ "success": True, "data": { "project_id": project_id, "task_id": task_id, "message": "图谱构建任务已启动,请通过 /task/{task_id} 查询进度" } }) except Exception as e: return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 # ============== 任务查询接口 ============== @graph_bp.route('/task/', methods=['GET']) def get_task(task_id: str): """ 查询任务状态 """ task = TaskManager().get_task(task_id) if not task: return jsonify({ "success": False, "error": f"任务不存在: {task_id}" }), 404 return jsonify({ "success": True, "data": task.to_dict() }) @graph_bp.route('/tasks', methods=['GET']) def list_tasks(): """ 列出所有任务 """ tasks = TaskManager().list_tasks() return jsonify({ "success": True, "data": [t.to_dict() for t in tasks], "count": len(tasks) }) # ============== 图谱数据接口 ============== @graph_bp.route('/data/', methods=['GET']) def get_graph_data(graph_id: str): """ 获取图谱数据(节点和边) """ try: if not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) graph_data = builder.get_graph_data(graph_id) return jsonify({ "success": True, "data": graph_data }) except Exception as e: return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 @graph_bp.route('/delete/', methods=['DELETE']) def delete_graph(graph_id: str): """ 删除Zep图谱 """ try: if not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) 
builder.delete_graph(graph_id) return jsonify({ "success": True, "message": f"图谱已删除: {graph_id}" }) except Exception as e: return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 ================================================ FILE: backend/app/api/report.py ================================================ """ Report API路由 提供模拟报告生成、获取、对话等接口 """ import os import traceback import threading from flask import request, jsonify, send_file from . import report_bp from ..config import Config from ..services.report_agent import ReportAgent, ReportManager, ReportStatus from ..services.simulation_manager import SimulationManager from ..models.project import ProjectManager from ..models.task import TaskManager, TaskStatus from ..utils.logger import get_logger logger = get_logger('mirofish.api.report') # ============== 报告生成接口 ============== @report_bp.route('/generate', methods=['POST']) def generate_report(): """ 生成模拟分析报告(异步任务) 这是一个耗时操作,接口会立即返回task_id, 使用 GET /api/report/generate/status 查询进度 请求(JSON): { "simulation_id": "sim_xxxx", // 必填,模拟ID "force_regenerate": false // 可选,强制重新生成 } 返回: { "success": true, "data": { "simulation_id": "sim_xxxx", "task_id": "task_xxxx", "status": "generating", "message": "报告生成任务已启动" } } """ try: data = request.get_json() or {} simulation_id = data.get('simulation_id') if not simulation_id: return jsonify({ "success": False, "error": "请提供 simulation_id" }), 400 force_regenerate = data.get('force_regenerate', False) # 获取模拟信息 manager = SimulationManager() state = manager.get_simulation(simulation_id) if not state: return jsonify({ "success": False, "error": f"模拟不存在: {simulation_id}" }), 404 # 检查是否已有报告 if not force_regenerate: existing_report = ReportManager.get_report_by_simulation(simulation_id) if existing_report and existing_report.status == ReportStatus.COMPLETED: return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "report_id": existing_report.report_id, "status": "completed", 
"message": "报告已存在", "already_generated": True } }) # 获取项目信息 project = ProjectManager.get_project(state.project_id) if not project: return jsonify({ "success": False, "error": f"项目不存在: {state.project_id}" }), 404 graph_id = state.graph_id or project.graph_id if not graph_id: return jsonify({ "success": False, "error": "缺少图谱ID,请确保已构建图谱" }), 400 simulation_requirement = project.simulation_requirement if not simulation_requirement: return jsonify({ "success": False, "error": "缺少模拟需求描述" }), 400 # 提前生成 report_id,以便立即返回给前端 import uuid report_id = f"report_{uuid.uuid4().hex[:12]}" # 创建异步任务 task_manager = TaskManager() task_id = task_manager.create_task( task_type="report_generate", metadata={ "simulation_id": simulation_id, "graph_id": graph_id, "report_id": report_id } ) # 定义后台任务 def run_generate(): try: task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=0, message="初始化Report Agent..." ) # 创建Report Agent agent = ReportAgent( graph_id=graph_id, simulation_id=simulation_id, simulation_requirement=simulation_requirement ) # 进度回调 def progress_callback(stage, progress, message): task_manager.update_task( task_id, progress=progress, message=f"[{stage}] {message}" ) # 生成报告(传入预先生成的 report_id) report = agent.generate_report( progress_callback=progress_callback, report_id=report_id ) # 保存报告 ReportManager.save_report(report) if report.status == ReportStatus.COMPLETED: task_manager.complete_task( task_id, result={ "report_id": report.report_id, "simulation_id": simulation_id, "status": "completed" } ) else: task_manager.fail_task(task_id, report.error or "报告生成失败") except Exception as e: logger.error(f"报告生成失败: {str(e)}") task_manager.fail_task(task_id, str(e)) # 启动后台线程 thread = threading.Thread(target=run_generate, daemon=True) thread.start() return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "report_id": report_id, "task_id": task_id, "status": "generating", "message": "报告生成任务已启动,请通过 /api/report/generate/status 查询进度", 
"already_generated": False } }) except Exception as e: logger.error(f"启动报告生成任务失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 @report_bp.route('/generate/status', methods=['POST']) def get_generate_status(): """ 查询报告生成任务进度 请求(JSON): { "task_id": "task_xxxx", // 可选,generate返回的task_id "simulation_id": "sim_xxxx" // 可选,模拟ID } 返回: { "success": true, "data": { "task_id": "task_xxxx", "status": "processing|completed|failed", "progress": 45, "message": "..." } } """ try: data = request.get_json() or {} task_id = data.get('task_id') simulation_id = data.get('simulation_id') # 如果提供了simulation_id,先检查是否已有完成的报告 if simulation_id: existing_report = ReportManager.get_report_by_simulation(simulation_id) if existing_report and existing_report.status == ReportStatus.COMPLETED: return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "report_id": existing_report.report_id, "status": "completed", "progress": 100, "message": "报告已生成", "already_completed": True } }) if not task_id: return jsonify({ "success": False, "error": "请提供 task_id 或 simulation_id" }), 400 task_manager = TaskManager() task = task_manager.get_task(task_id) if not task: return jsonify({ "success": False, "error": f"任务不存在: {task_id}" }), 404 return jsonify({ "success": True, "data": task.to_dict() }) except Exception as e: logger.error(f"查询任务状态失败: {str(e)}") return jsonify({ "success": False, "error": str(e) }), 500 # ============== 报告获取接口 ============== @report_bp.route('/', methods=['GET']) def get_report(report_id: str): """ 获取报告详情 返回: { "success": true, "data": { "report_id": "report_xxxx", "simulation_id": "sim_xxxx", "status": "completed", "outline": {...}, "markdown_content": "...", "created_at": "...", "completed_at": "..." 
@report_bp.route('/<report_id>/download', methods=['GET'])
def download_report(report_id: str):
    """
    Download a report as a Markdown attachment named ``<report_id>.md``.

    The Markdown file stored for the report is sent when it exists on disk.
    Otherwise the report's in-memory ``markdown_content`` is sent instead.

    Args:
        report_id: Report identifier taken from the URL path.

    Returns:
        The file response, a 404 JSON error when the report is unknown, or a
        500 JSON error (with traceback) when anything raises.
    """
    try:
        report = ReportManager.get_report(report_id)

        if not report:
            return jsonify({
                "success": False,
                "error": f"报告不存在: {report_id}"
            }), 404

        md_path = ReportManager._get_report_markdown_path(report_id)

        if os.path.exists(md_path):
            return send_file(
                md_path,
                as_attachment=True,
                download_name=f"{report_id}.md"
            )

        # No Markdown file on disk: serve the content from memory.
        # This replaces a NamedTemporaryFile(delete=False) that was never
        # removed (one leaked file per request) and was written with the
        # platform default encoding, which can fail on non-ASCII report text.
        import io

        payload = io.BytesIO((report.markdown_content or "").encode('utf-8'))
        return send_file(
            payload,
            as_attachment=True,
            download_name=f"{report_id}.md"
        )

    except Exception as e:
        logger.error(f"下载报告失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
"error": f"项目不存在: {state.project_id}" }), 404 graph_id = state.graph_id or project.graph_id if not graph_id: return jsonify({ "success": False, "error": "缺少图谱ID" }), 400 simulation_requirement = project.simulation_requirement or "" # 创建Agent并进行对话 agent = ReportAgent( graph_id=graph_id, simulation_id=simulation_id, simulation_requirement=simulation_requirement ) result = agent.chat(message=message, chat_history=chat_history) return jsonify({ "success": True, "data": result }) except Exception as e: logger.error(f"对话失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 # ============== 报告进度与分章节接口 ============== @report_bp.route('//progress', methods=['GET']) def get_report_progress(report_id: str): """ 获取报告生成进度(实时) 返回: { "success": true, "data": { "status": "generating", "progress": 45, "message": "正在生成章节: 关键发现", "current_section": "关键发现", "completed_sections": ["执行摘要", "模拟背景"], "updated_at": "2025-12-09T..." } } """ try: progress = ReportManager.get_progress(report_id) if not progress: return jsonify({ "success": False, "error": f"报告不存在或进度信息不可用: {report_id}" }), 404 return jsonify({ "success": True, "data": progress }) except Exception as e: logger.error(f"获取报告进度失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 @report_bp.route('//sections', methods=['GET']) def get_report_sections(report_id: str): """ 获取已生成的章节列表(分章节输出) 前端可以轮询此接口获取已生成的章节内容,无需等待整个报告完成 返回: { "success": true, "data": { "report_id": "report_xxxx", "sections": [ { "filename": "section_01.md", "section_index": 1, "content": "## 执行摘要\\n\\n..." }, ... 
@report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET'])
def get_single_section(report_id: str, section_index: int):
    """
    Return the content of one generated report section.

    Args:
        report_id: Report identifier taken from the URL path.
        section_index: Section number taken from the URL path. The ``int``
            converter is required because the value is formatted with
            ``:02d`` below, which would fail on a string.

    Returns:
        JSON ``{"success": True, "data": {"filename", "section_index",
        "content"}}``, a 404 JSON error when the section file does not exist,
        or a 500 JSON error (with traceback) when anything raises.
    """
    try:
        section_path = ReportManager._get_section_path(report_id, section_index)

        if not os.path.exists(section_path):
            return jsonify({
                "success": False,
                "error": f"章节不存在: section_{section_index:02d}.md"
            }), 404

        with open(section_path, 'r', encoding='utf-8') as f:
            content = f.read()

        return jsonify({
            "success": True,
            "data": {
                "filename": f"section_{section_index:02d}.md",
                "section_index": section_index,
                "content": content
            }
        })

    except Exception as e:
        logger.error(f"获取章节内容失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@report_bp.route('/<report_id>/agent-log', methods=['GET'])
def get_agent_log(report_id: str):
    """
    Return the Report Agent's structured execution log for a report.

    Intended for incremental polling: the caller passes the line offset it
    has already consumed and receives whatever ``ReportManager.get_agent_log``
    returns from that offset.

    Args:
        report_id: Report identifier taken from the URL path.

    Query parameters:
        from_line: Line offset to start reading from (integer, default 0).

    Returns:
        JSON ``{"success": True, "data": <log data>}``, or a 500 JSON error
        (with traceback) when anything raises.
    """
    try:
        # type=int makes Flask fall back to the default (0) on non-numeric input.
        from_line = request.args.get('from_line', 0, type=int)

        log_data = ReportManager.get_agent_log(report_id, from_line=from_line)

        return jsonify({
            "success": True,
            "data": log_data
        })

    except Exception as e:
        logger.error(f"获取Agent日志失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
的控制台输出日志 实时获取报告生成过程中的控制台输出(INFO、WARNING等), 这与 agent-log 接口返回的结构化 JSON 日志不同, 是纯文本格式的控制台风格日志。 Query参数: from_line: 从第几行开始读取(可选,默认0,用于增量获取) 返回: { "success": true, "data": { "logs": [ "[19:46:14] INFO: 搜索完成: 找到 15 条相关事实", "[19:46:14] INFO: 图谱搜索: graph_id=xxx, query=...", ... ], "total_lines": 100, "from_line": 0, "has_more": false } } """ try: from_line = request.args.get('from_line', 0, type=int) log_data = ReportManager.get_console_log(report_id, from_line=from_line) return jsonify({ "success": True, "data": log_data }) except Exception as e: logger.error(f"获取控制台日志失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 @report_bp.route('//console-log/stream', methods=['GET']) def stream_console_log(report_id: str): """ 获取完整的控制台日志(一次性获取全部) 返回: { "success": true, "data": { "logs": [...], "count": 100 } } """ try: logs = ReportManager.get_console_log_stream(report_id) return jsonify({ "success": True, "data": { "logs": logs, "count": len(logs) } }) except Exception as e: logger.error(f"获取控制台日志失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 # ============== 工具调用接口(供调试使用)============== @report_bp.route('/tools/search', methods=['POST']) def search_graph_tool(): """ 图谱搜索工具接口(供调试使用) 请求(JSON): { "graph_id": "mirofish_xxxx", "query": "搜索查询", "limit": 10 } """ try: data = request.get_json() or {} graph_id = data.get('graph_id') query = data.get('query') limit = data.get('limit', 10) if not graph_id or not query: return jsonify({ "success": False, "error": "请提供 graph_id 和 query" }), 400 from ..services.zep_tools import ZepToolsService tools = ZepToolsService() result = tools.search_graph( graph_id=graph_id, query=query, limit=limit ) return jsonify({ "success": True, "data": result.to_dict() }) except Exception as e: logger.error(f"图谱搜索失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 
# Prefix prepended to interview questions. It instructs the agent to answer
# in plain text, based on its persona and history, without calling tools.
INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:"


def optimize_interview_prompt(prompt: str) -> str:
    """
    Prepend the interview prefix to a question unless it is already there.

    Args:
        prompt: The raw interview question.

    Returns:
        ``prompt`` unchanged when it is falsy (``None`` or empty) or already
        starts with ``INTERVIEW_PROMPT_PREFIX``; otherwise the prefix followed
        by ``prompt``.
    """
    needs_prefix = bool(prompt) and not prompt.startswith(INTERVIEW_PROMPT_PREFIX)
    return INTERVIEW_PROMPT_PREFIX + prompt if needs_prefix else prompt
@simulation_bp.route('/entities/<graph_id>/<entity_uuid>', methods=['GET'])
def get_entity_detail(graph_id: str, entity_uuid: str):
    """
    Return the details of a single graph entity.

    Args:
        graph_id: Graph identifier taken from the URL path.
        entity_uuid: Entity identifier taken from the URL path.

    Returns:
        JSON ``{"success": True, "data": entity.to_dict()}``; a 500 JSON error
        when ``Config.ZEP_API_KEY`` is not set; a 404 JSON error when the
        reader returns no entity; or a 500 JSON error (with traceback) when
        anything raises.
    """
    try:
        if not Config.ZEP_API_KEY:
            return jsonify({
                "success": False,
                "error": "ZEP_API_KEY未配置"
            }), 500

        reader = ZepEntityReader()
        entity = reader.get_entity_with_context(graph_id, entity_uuid)

        if not entity:
            return jsonify({
                "success": False,
                "error": f"实体不存在: {entity_uuid}"
            }), 404

        return jsonify({
            "success": True,
            "data": entity.to_dict()
        })

    except Exception as e:
        logger.error(f"获取实体详情失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
entities] } }) except Exception as e: logger.error(f"获取实体失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 # ============== 模拟管理接口 ============== @simulation_bp.route('/create', methods=['POST']) def create_simulation(): """ 创建新的模拟 注意:max_rounds等参数由LLM智能生成,无需手动设置 请求(JSON): { "project_id": "proj_xxxx", // 必填 "graph_id": "mirofish_xxxx", // 可选,如不提供则从project获取 "enable_twitter": true, // 可选,默认true "enable_reddit": true // 可选,默认true } 返回: { "success": true, "data": { "simulation_id": "sim_xxxx", "project_id": "proj_xxxx", "graph_id": "mirofish_xxxx", "status": "created", "enable_twitter": true, "enable_reddit": true, "created_at": "2025-12-01T10:00:00" } } """ try: data = request.get_json() or {} project_id = data.get('project_id') if not project_id: return jsonify({ "success": False, "error": "请提供 project_id" }), 400 project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, "error": f"项目不存在: {project_id}" }), 404 graph_id = data.get('graph_id') or project.graph_id if not graph_id: return jsonify({ "success": False, "error": "项目尚未构建图谱,请先调用 /api/graph/build" }), 400 manager = SimulationManager() state = manager.create_simulation( project_id=project_id, graph_id=graph_id, enable_twitter=data.get('enable_twitter', True), enable_reddit=data.get('enable_reddit', True), ) return jsonify({ "success": True, "data": state.to_dict() }) except Exception as e: logger.error(f"创建模拟失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 def _check_simulation_prepared(simulation_id: str) -> tuple: """ 检查模拟是否已经准备完成 检查条件: 1. state.json 存在且 status 为 "ready" 2. 
必要文件存在:reddit_profiles.json, twitter_profiles.csv, simulation_config.json 注意:运行脚本(run_*.py)保留在 backend/scripts/ 目录,不再复制到模拟目录 Args: simulation_id: 模拟ID Returns: (is_prepared: bool, info: dict) """ import os from ..config import Config simulation_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) # 检查目录是否存在 if not os.path.exists(simulation_dir): return False, {"reason": "模拟目录不存在"} # 必要文件列表(不包括脚本,脚本位于 backend/scripts/) required_files = [ "state.json", "simulation_config.json", "reddit_profiles.json", "twitter_profiles.csv" ] # 检查文件是否存在 existing_files = [] missing_files = [] for f in required_files: file_path = os.path.join(simulation_dir, f) if os.path.exists(file_path): existing_files.append(f) else: missing_files.append(f) if missing_files: return False, { "reason": "缺少必要文件", "missing_files": missing_files, "existing_files": existing_files } # 检查state.json中的状态 state_file = os.path.join(simulation_dir, "state.json") try: import json with open(state_file, 'r', encoding='utf-8') as f: state_data = json.load(f) status = state_data.get("status", "") config_generated = state_data.get("config_generated", False) # 详细日志 logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}") # 如果 config_generated=True 且文件存在,认为准备完成 # 以下状态都说明准备工作已完成: # - ready: 准备完成,可以运行 # - preparing: 如果 config_generated=True 说明已完成 # - running: 正在运行,说明准备早就完成了 # - completed: 运行完成,说明准备早就完成了 # - stopped: 已停止,说明准备早就完成了 # - failed: 运行失败(但准备是完成的) prepared_statuses = ["ready", "preparing", "running", "completed", "stopped", "failed"] if status in prepared_statuses and config_generated: # 获取文件统计信息 profiles_file = os.path.join(simulation_dir, "reddit_profiles.json") config_file = os.path.join(simulation_dir, "simulation_config.json") profiles_count = 0 if os.path.exists(profiles_file): with open(profiles_file, 'r', encoding='utf-8') as f: profiles_data = json.load(f) profiles_count = len(profiles_data) if isinstance(profiles_data, list) else 0 # 
如果状态是preparing但文件已完成,自动更新状态为ready if status == "preparing": try: state_data["status"] = "ready" from datetime import datetime state_data["updated_at"] = datetime.now().isoformat() with open(state_file, 'w', encoding='utf-8') as f: json.dump(state_data, f, ensure_ascii=False, indent=2) logger.info(f"自动更新模拟状态: {simulation_id} preparing -> ready") status = "ready" except Exception as e: logger.warning(f"自动更新状态失败: {e}") logger.info(f"模拟 {simulation_id} 检测结果: 已准备完成 (status={status}, config_generated={config_generated})") return True, { "status": status, "entities_count": state_data.get("entities_count", 0), "profiles_count": profiles_count, "entity_types": state_data.get("entity_types", []), "config_generated": config_generated, "created_at": state_data.get("created_at"), "updated_at": state_data.get("updated_at"), "existing_files": existing_files } else: logger.warning(f"模拟 {simulation_id} 检测结果: 未准备完成 (status={status}, config_generated={config_generated})") return False, { "reason": f"状态不在已准备列表中或config_generated为false: status={status}, config_generated={config_generated}", "status": status, "config_generated": config_generated } except Exception as e: return False, {"reason": f"读取状态文件失败: {str(e)}"} @simulation_bp.route('/prepare', methods=['POST']) def prepare_simulation(): """ 准备模拟环境(异步任务,LLM智能生成所有参数) 这是一个耗时操作,接口会立即返回task_id, 使用 GET /api/simulation/prepare/status 查询进度 特性: - 自动检测已完成的准备工作,避免重复生成 - 如果已准备完成,直接返回已有结果 - 支持强制重新生成(force_regenerate=true) 步骤: 1. 检查是否已有完成的准备工作 2. 从Zep图谱读取并过滤实体 3. 为每个实体生成OASIS Agent Profile(带重试机制) 4. LLM智能生成模拟配置(带重试机制) 5. 
保存配置文件和预设脚本 请求(JSON): { "simulation_id": "sim_xxxx", // 必填,模拟ID "entity_types": ["Student", "PublicFigure"], // 可选,指定实体类型 "use_llm_for_profiles": true, // 可选,是否用LLM生成人设 "parallel_profile_count": 5, // 可选,并行生成人设数量,默认5 "force_regenerate": false // 可选,强制重新生成,默认false } 返回: { "success": true, "data": { "simulation_id": "sim_xxxx", "task_id": "task_xxxx", // 新任务时返回 "status": "preparing|ready", "message": "准备任务已启动|已有完成的准备工作", "already_prepared": true|false // 是否已准备完成 } } """ import threading import os from ..models.task import TaskManager, TaskStatus from ..config import Config try: data = request.get_json() or {} simulation_id = data.get('simulation_id') if not simulation_id: return jsonify({ "success": False, "error": "请提供 simulation_id" }), 400 manager = SimulationManager() state = manager.get_simulation(simulation_id) if not state: return jsonify({ "success": False, "error": f"模拟不存在: {simulation_id}" }), 404 # 检查是否强制重新生成 force_regenerate = data.get('force_regenerate', False) logger.info(f"开始处理 /prepare 请求: simulation_id={simulation_id}, force_regenerate={force_regenerate}") # 检查是否已经准备完成(避免重复生成) if not force_regenerate: logger.debug(f"检查模拟 {simulation_id} 是否已准备完成...") is_prepared, prepare_info = _check_simulation_prepared(simulation_id) logger.debug(f"检查结果: is_prepared={is_prepared}, prepare_info={prepare_info}") if is_prepared: logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成") return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "status": "ready", "message": "已有完成的准备工作,无需重复生成", "already_prepared": True, "prepare_info": prepare_info } }) else: logger.info(f"模拟 {simulation_id} 未准备完成,将启动准备任务") # 从项目获取必要信息 project = ProjectManager.get_project(state.project_id) if not project: return jsonify({ "success": False, "error": f"项目不存在: {state.project_id}" }), 404 # 获取模拟需求 simulation_requirement = project.simulation_requirement or "" if not simulation_requirement: return jsonify({ "success": False, "error": "项目缺少模拟需求描述 (simulation_requirement)" }), 400 # 
获取文档文本 document_text = ProjectManager.get_extracted_text(state.project_id) or "" entity_types_list = data.get('entity_types') use_llm_for_profiles = data.get('use_llm_for_profiles', True) parallel_profile_count = data.get('parallel_profile_count', 5) # ========== 同步获取实体数量(在后台任务启动前) ========== # 这样前端在调用prepare后立即就能获取到预期Agent总数 try: logger.info(f"同步获取实体数量: graph_id={state.graph_id}") reader = ZepEntityReader() # 快速读取实体(不需要边信息,只统计数量) filtered_preview = reader.filter_defined_entities( graph_id=state.graph_id, defined_entity_types=entity_types_list, enrich_with_edges=False # 不获取边信息,加快速度 ) # 保存实体数量到状态(供前端立即获取) state.entities_count = filtered_preview.filtered_count state.entity_types = list(filtered_preview.entity_types) logger.info(f"预期实体数量: {filtered_preview.filtered_count}, 类型: {filtered_preview.entity_types}") except Exception as e: logger.warning(f"同步获取实体数量失败(将在后台任务中重试): {e}") # 失败不影响后续流程,后台任务会重新获取 # 创建异步任务 task_manager = TaskManager() task_id = task_manager.create_task( task_type="simulation_prepare", metadata={ "simulation_id": simulation_id, "project_id": state.project_id } ) # 更新模拟状态(包含预先获取的实体数量) state.status = SimulationStatus.PREPARING manager._save_simulation_state(state) # 定义后台任务 def run_prepare(): try: task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=0, message="开始准备模拟环境..." 
) # 准备模拟(带进度回调) # 存储阶段进度详情 stage_details = {} def progress_callback(stage, progress, message, **kwargs): # 计算总进度 stage_weights = { "reading": (0, 20), # 0-20% "generating_profiles": (20, 70), # 20-70% "generating_config": (70, 90), # 70-90% "copying_scripts": (90, 100) # 90-100% } start, end = stage_weights.get(stage, (0, 100)) current_progress = int(start + (end - start) * progress / 100) # 构建详细进度信息 stage_names = { "reading": "读取图谱实体", "generating_profiles": "生成Agent人设", "generating_config": "生成模拟配置", "copying_scripts": "准备模拟脚本" } stage_index = list(stage_weights.keys()).index(stage) + 1 if stage in stage_weights else 1 total_stages = len(stage_weights) # 更新阶段详情 stage_details[stage] = { "stage_name": stage_names.get(stage, stage), "stage_progress": progress, "current": kwargs.get("current", 0), "total": kwargs.get("total", 0), "item_name": kwargs.get("item_name", "") } # 构建详细进度信息 detail = stage_details[stage] progress_detail_data = { "current_stage": stage, "current_stage_name": stage_names.get(stage, stage), "stage_index": stage_index, "total_stages": total_stages, "stage_progress": progress, "current_item": detail["current"], "total_items": detail["total"], "item_description": message } # 构建简洁消息 if detail["total"] > 0: detailed_message = ( f"[{stage_index}/{total_stages}] {stage_names.get(stage, stage)}: " f"{detail['current']}/{detail['total']} - {message}" ) else: detailed_message = f"[{stage_index}/{total_stages}] {stage_names.get(stage, stage)}: {message}" task_manager.update_task( task_id, progress=current_progress, message=detailed_message, progress_detail=progress_detail_data ) result_state = manager.prepare_simulation( simulation_id=simulation_id, simulation_requirement=simulation_requirement, document_text=document_text, defined_entity_types=entity_types_list, use_llm_for_profiles=use_llm_for_profiles, progress_callback=progress_callback, parallel_profile_count=parallel_profile_count ) # 任务完成 task_manager.complete_task( task_id, 
result=result_state.to_simple_dict() ) except Exception as e: logger.error(f"准备模拟失败: {str(e)}") task_manager.fail_task(task_id, str(e)) # 更新模拟状态为失败 state = manager.get_simulation(simulation_id) if state: state.status = SimulationStatus.FAILED state.error = str(e) manager._save_simulation_state(state) # 启动后台线程 thread = threading.Thread(target=run_prepare, daemon=True) thread.start() return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "task_id": task_id, "status": "preparing", "message": "准备任务已启动,请通过 /api/simulation/prepare/status 查询进度", "already_prepared": False, "expected_entities_count": state.entities_count, # 预期的Agent总数 "entity_types": state.entity_types # 实体类型列表 } }) except ValueError as e: return jsonify({ "success": False, "error": str(e) }), 404 except Exception as e: logger.error(f"启动准备任务失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 @simulation_bp.route('/prepare/status', methods=['POST']) def get_prepare_status(): """ 查询准备任务进度 支持两种查询方式: 1. 通过task_id查询正在进行的任务进度 2. 
通过simulation_id检查是否已有完成的准备工作 请求(JSON): { "task_id": "task_xxxx", // 可选,prepare返回的task_id "simulation_id": "sim_xxxx" // 可选,模拟ID(用于检查已完成的准备) } 返回: { "success": true, "data": { "task_id": "task_xxxx", "status": "processing|completed|ready", "progress": 45, "message": "...", "already_prepared": true|false, // 是否已有完成的准备 "prepare_info": {...} // 已准备完成时的详细信息 } } """ from ..models.task import TaskManager try: data = request.get_json() or {} task_id = data.get('task_id') simulation_id = data.get('simulation_id') # 如果提供了simulation_id,先检查是否已准备完成 if simulation_id: is_prepared, prepare_info = _check_simulation_prepared(simulation_id) if is_prepared: return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "status": "ready", "progress": 100, "message": "已有完成的准备工作", "already_prepared": True, "prepare_info": prepare_info } }) # 如果没有task_id,返回错误 if not task_id: if simulation_id: # 有simulation_id但未准备完成 return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "status": "not_started", "progress": 0, "message": "尚未开始准备,请调用 /api/simulation/prepare 开始", "already_prepared": False } }) return jsonify({ "success": False, "error": "请提供 task_id 或 simulation_id" }), 400 task_manager = TaskManager() task = task_manager.get_task(task_id) if not task: # 任务不存在,但如果有simulation_id,检查是否已准备完成 if simulation_id: is_prepared, prepare_info = _check_simulation_prepared(simulation_id) if is_prepared: return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "task_id": task_id, "status": "ready", "progress": 100, "message": "任务已完成(准备工作已存在)", "already_prepared": True, "prepare_info": prepare_info } }) return jsonify({ "success": False, "error": f"任务不存在: {task_id}" }), 404 task_dict = task.to_dict() task_dict["already_prepared"] = False return jsonify({ "success": True, "data": task_dict }) except Exception as e: logger.error(f"查询任务状态失败: {str(e)}") return jsonify({ "success": False, "error": str(e) }), 500 @simulation_bp.route('/', methods=['GET']) def 
get_simulation(simulation_id: str): """获取模拟状态""" try: manager = SimulationManager() state = manager.get_simulation(simulation_id) if not state: return jsonify({ "success": False, "error": f"模拟不存在: {simulation_id}" }), 404 result = state.to_dict() # 如果模拟已准备好,附加运行说明 if state.status == SimulationStatus.READY: result["run_instructions"] = manager.get_run_instructions(simulation_id) return jsonify({ "success": True, "data": result }) except Exception as e: logger.error(f"获取模拟状态失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 @simulation_bp.route('/list', methods=['GET']) def list_simulations(): """ 列出所有模拟 Query参数: project_id: 按项目ID过滤(可选) """ try: project_id = request.args.get('project_id') manager = SimulationManager() simulations = manager.list_simulations(project_id=project_id) return jsonify({ "success": True, "data": [s.to_dict() for s in simulations], "count": len(simulations) }) except Exception as e: logger.error(f"列出模拟失败: {str(e)}") return jsonify({ "success": False, "error": str(e), "traceback": traceback.format_exc() }), 500 def _get_report_id_for_simulation(simulation_id: str) -> str: """ 获取 simulation 对应的最新 report_id 遍历 reports 目录,找出 simulation_id 匹配的 report, 如果有多个则返回最新的(按 created_at 排序) Args: simulation_id: 模拟ID Returns: report_id 或 None """ import json from datetime import datetime # reports 目录路径:backend/uploads/reports # __file__ 是 app/api/simulation.py,需要向上两级到 backend/ reports_dir = os.path.join(os.path.dirname(__file__), '../../uploads/reports') if not os.path.exists(reports_dir): return None matching_reports = [] try: for report_folder in os.listdir(reports_dir): report_path = os.path.join(reports_dir, report_folder) if not os.path.isdir(report_path): continue meta_file = os.path.join(report_path, "meta.json") if not os.path.exists(meta_file): continue try: with open(meta_file, 'r', encoding='utf-8') as f: meta = json.load(f) if meta.get("simulation_id") == simulation_id: 
matching_reports.append({ "report_id": meta.get("report_id"), "created_at": meta.get("created_at", ""), "status": meta.get("status", "") }) except Exception: continue if not matching_reports: return None # 按创建时间倒序排序,返回最新的 matching_reports.sort(key=lambda x: x.get("created_at", ""), reverse=True) return matching_reports[0].get("report_id") except Exception as e: logger.warning(f"查找 simulation {simulation_id} 的 report 失败: {e}") return None @simulation_bp.route('/history', methods=['GET']) def get_simulation_history(): """ 获取历史模拟列表(带项目详情) 用于首页历史项目展示,返回包含项目名称、描述等丰富信息的模拟列表 Query参数: limit: 返回数量限制(默认20) 返回: { "success": true, "data": [ { "simulation_id": "sim_xxxx", "project_id": "proj_xxxx", "project_name": "武大舆情分析", "simulation_requirement": "如果武汉大学发布...", "status": "completed", "entities_count": 68, "profiles_count": 68, "entity_types": ["Student", "Professor", ...], "created_at": "2024-12-10", "updated_at": "2024-12-10", "total_rounds": 120, "current_round": 120, "report_id": "report_xxxx", "version": "v1.0.2" }, ... 
], "count": 7 } """ try: limit = request.args.get('limit', 20, type=int) manager = SimulationManager() simulations = manager.list_simulations()[:limit] # 增强模拟数据,只从 Simulation 文件读取 enriched_simulations = [] for sim in simulations: sim_dict = sim.to_dict() # 获取模拟配置信息(从 simulation_config.json 读取 simulation_requirement) config = manager.get_simulation_config(sim.simulation_id) if config: sim_dict["simulation_requirement"] = config.get("simulation_requirement", "") time_config = config.get("time_config", {}) sim_dict["total_simulation_hours"] = time_config.get("total_simulation_hours", 0) # 推荐轮数(后备值) recommended_rounds = int( time_config.get("total_simulation_hours", 0) * 60 / max(time_config.get("minutes_per_round", 60), 1) ) else: sim_dict["simulation_requirement"] = "" sim_dict["total_simulation_hours"] = 0 recommended_rounds = 0 # 获取运行状态(从 run_state.json 读取用户设置的实际轮数) run_state = SimulationRunner.get_run_state(sim.simulation_id) if run_state: sim_dict["current_round"] = run_state.current_round sim_dict["runner_status"] = run_state.runner_status.value # 使用用户设置的 total_rounds,若无则使用推荐轮数 sim_dict["total_rounds"] = run_state.total_rounds if run_state.total_rounds > 0 else recommended_rounds else: sim_dict["current_round"] = 0 sim_dict["runner_status"] = "idle" sim_dict["total_rounds"] = recommended_rounds # 获取关联项目的文件列表(最多3个) project = ProjectManager.get_project(sim.project_id) if project and hasattr(project, 'files') and project.files: sim_dict["files"] = [ {"filename": f.get("filename", "未知文件")} for f in project.files[:3] ] else: sim_dict["files"] = [] # 获取关联的 report_id(查找该 simulation 最新的 report) sim_dict["report_id"] = _get_report_id_for_simulation(sim.simulation_id) # 添加版本号 sim_dict["version"] = "v1.0.2" # 格式化日期 try: created_date = sim_dict.get("created_at", "")[:10] sim_dict["created_date"] = created_date except: sim_dict["created_date"] = "" enriched_simulations.append(sim_dict) return jsonify({ "success": True, "data": enriched_simulations, "count": 
@simulation_bp.route('/<simulation_id>/profiles', methods=['GET'])
def get_simulation_profiles(simulation_id: str):
    """
    Return the agent profiles of a simulation.

    The URL rule must declare ``<simulation_id>``; without it the view is
    invoked with no argument and fails before its body runs.

    Query parameters:
        platform: platform type (reddit/twitter, defaults to reddit)

    Returns:
        200 with ``{"platform", "count", "profiles"}`` on success,
        404 when ``SimulationManager.get_profiles`` raises ``ValueError``,
        500 on any other error.
    """
    try:
        platform = request.args.get('platform', 'reddit')

        manager = SimulationManager()
        profiles = manager.get_profiles(simulation_id, platform=platform)

        return jsonify({
            "success": True,
            "data": {
                "platform": platform,
                "count": len(profiles),
                "profiles": profiles
            }
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 404

    except Exception as e:
        logger.error(f"获取Profile失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@simulation_bp.route('/<simulation_id>/config/realtime', methods=['GET'])
def get_simulation_config_realtime(simulation_id: str):
    """
    Return the simulation config as it currently exists on disk.

    Unlike the /config endpoint, this reads ``simulation_config.json`` and
    ``state.json`` directly from the simulation directory instead of going
    through SimulationManager, so it can be polled while generation is still
    running. A file that cannot be parsed (for example because it is being
    written) is reported as ``config: null`` rather than as an error.

    Response:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "file_exists": true,
                "file_modified_at": "2025-12-04T18:20:00",
                "is_generating": true,
                "generation_stage": "generating_profiles|generating_config|completed|null",
                "config_generated": false,
                "config": {...},       // null when missing or unreadable
                "summary": {...}       // only when config is a non-empty object
            }
        }

    Returns 404 when the simulation directory does not exist and 500 on
    unexpected errors.
    """
    import json
    from datetime import datetime

    try:
        sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)

        if not os.path.exists(sim_dir):
            return jsonify({
                "success": False,
                "error": f"模拟不存在: {simulation_id}"
            }), 404

        config_file = os.path.join(sim_dir, "simulation_config.json")

        file_exists = os.path.exists(config_file)
        config = None
        file_modified_at = None

        if file_exists:
            file_stat = os.stat(config_file)
            file_modified_at = datetime.fromtimestamp(file_stat.st_mtime).isoformat()

            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    config = json.load(f)
            except Exception as e:
                # json.JSONDecodeError is a subclass of Exception, so one
                # clause covers both parse errors and I/O errors.
                logger.warning(f"读取 config 文件失败(可能正在写入中): {e}")
                config = None

        # Generation progress is derived from state.json (best effort).
        is_generating = False
        generation_stage = None
        config_generated = False

        state_file = os.path.join(sim_dir, "state.json")
        if os.path.exists(state_file):
            try:
                with open(state_file, 'r', encoding='utf-8') as f:
                    state_data = json.load(f)
                status = state_data.get("status", "")
                is_generating = status == "preparing"
                config_generated = state_data.get("config_generated", False)

                if is_generating:
                    if state_data.get("profiles_generated", False):
                        generation_stage = "generating_config"
                    else:
                        generation_stage = "generating_profiles"
                elif status == "ready":
                    generation_stage = "completed"
            except Exception as e:
                # Keep the defaults, but leave a trace instead of swallowing
                # the error silently.
                logger.debug(f"读取 state 文件失败(可能正在写入中): {e}")

        response_data = {
            "simulation_id": simulation_id,
            "file_exists": file_exists,
            "file_modified_at": file_modified_at,
            "is_generating": is_generating,
            "generation_stage": generation_stage,
            "config_generated": config_generated,
            "config": config
        }

        # Summarise only a well-formed config object. ``or {}`` / ``or []``
        # protect against sections that are present but null.
        if config and isinstance(config, dict):
            time_config = config.get("time_config") or {}
            event_config = config.get("event_config") or {}
            response_data["summary"] = {
                "total_agents": len(config.get("agent_configs") or []),
                "simulation_hours": time_config.get("total_simulation_hours"),
                "initial_posts_count": len(event_config.get("initial_posts") or []),
                "hot_topics_count": len(event_config.get("hot_topics") or []),
                "has_twitter_config": "twitter_config" in config,
                "has_reddit_config": "reddit_config" in config,
                "generated_at": config.get("generated_at"),
                "llm_model": config.get("llm_model")
            }

        return jsonify({
            "success": True,
            "data": response_data
        })

    except Exception as e:
        logger.error(f"实时获取Config失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@simulation_bp.route('/generate-profiles', methods=['POST'])
def generate_profiles():
    """
    Generate OASIS agent profiles straight from a graph, without creating a
    simulation.

    Request (JSON):
        {
            "graph_id": "mirofish_xxxx",     // required
            "entity_types": ["Student"],     // optional
            "use_llm": true,                 // optional, defaults to true
            "platform": "reddit"             // optional, defaults to reddit
        }

    Returns 400 when ``graph_id`` is missing or no entity matches the filter,
    500 on unexpected errors.
    """
    try:
        body = request.get_json() or {}

        graph_id = body.get('graph_id')
        if not graph_id:
            return jsonify({
                "success": False,
                "error": "请提供 graph_id"
            }), 400

        entity_types = body.get('entity_types')
        use_llm = body.get('use_llm', True)
        platform = body.get('platform', 'reddit')

        filtered = ZepEntityReader().filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=entity_types,
            enrich_with_edges=True
        )
        if filtered.filtered_count == 0:
            return jsonify({
                "success": False,
                "error": "没有找到符合条件的实体"
            }), 400

        profiles = OasisProfileGenerator().generate_profiles_from_entities(
            entities=filtered.entities,
            use_llm=use_llm
        )

        # Pick the serialiser that matches the requested platform; anything
        # other than reddit/twitter falls back to the generic dict form.
        if platform == "reddit":
            serializer_name = "to_reddit_format"
        elif platform == "twitter":
            serializer_name = "to_twitter_format"
        else:
            serializer_name = "to_dict"

        profiles_data = []
        for profile in profiles:
            profiles_data.append(getattr(profile, serializer_name)())

        return jsonify({
            "success": True,
            "data": {
                "platform": platform,
                "entity_types": list(filtered.entity_types),
                "count": len(profiles_data),
                "profiles": profiles_data
            }
        })

    except Exception as e:
        logger.error(f"生成Profile失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
# 可选:是否启用图谱记忆更新 force = data.get('force', False) # 可选:强制重新开始 # 验证 max_rounds 参数 if max_rounds is not None: try: max_rounds = int(max_rounds) if max_rounds <= 0: return jsonify({ "success": False, "error": "max_rounds 必须是正整数" }), 400 except (ValueError, TypeError): return jsonify({ "success": False, "error": "max_rounds 必须是有效的整数" }), 400 if platform not in ['twitter', 'reddit', 'parallel']: return jsonify({ "success": False, "error": f"无效的平台类型: {platform},可选: twitter/reddit/parallel" }), 400 # 检查模拟是否已准备好 manager = SimulationManager() state = manager.get_simulation(simulation_id) if not state: return jsonify({ "success": False, "error": f"模拟不存在: {simulation_id}" }), 404 force_restarted = False # 智能处理状态:如果准备工作已完成,允许重新启动 if state.status != SimulationStatus.READY: # 检查准备工作是否已完成 is_prepared, prepare_info = _check_simulation_prepared(simulation_id) if is_prepared: # 准备工作已完成,检查是否有正在运行的进程 if state.status == SimulationStatus.RUNNING: # 检查模拟进程是否真的在运行 run_state = SimulationRunner.get_run_state(simulation_id) if run_state and run_state.runner_status.value == "running": # 进程确实在运行 if force: # 强制模式:停止运行中的模拟 logger.info(f"强制模式:停止运行中的模拟 {simulation_id}") try: SimulationRunner.stop_simulation(simulation_id) except Exception as e: logger.warning(f"停止模拟时出现警告: {str(e)}") else: return jsonify({ "success": False, "error": f"模拟正在运行中,请先调用 /stop 接口停止,或使用 force=true 强制重新开始" }), 400 # 如果是强制模式,清理运行日志 if force: logger.info(f"强制模式:清理模拟日志 {simulation_id}") cleanup_result = SimulationRunner.cleanup_simulation_logs(simulation_id) if not cleanup_result.get("success"): logger.warning(f"清理日志时出现警告: {cleanup_result.get('errors')}") force_restarted = True # 进程不存在或已结束,重置状态为 ready logger.info(f"模拟 {simulation_id} 准备工作已完成,重置状态为 ready(原状态: {state.status.value})") state.status = SimulationStatus.READY manager._save_simulation_state(state) else: # 准备工作未完成 return jsonify({ "success": False, "error": f"模拟未准备好,当前状态: {state.status.value},请先调用 /prepare 接口" }), 400 # 获取图谱ID(用于图谱记忆更新) graph_id = None if 
@simulation_bp.route('/stop', methods=['POST'])
def stop_simulation():
    """
    Stop a simulation.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx"   // required
        }

    Response:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "runner_status": "stopped",
                "completed_at": "2025-12-01T12:00:00"
            }
        }

    After the runner is stopped, the stored simulation state (if any) is
    marked as PAUSED. Returns 400 when ``simulation_id`` is missing or the
    runner raises ``ValueError``, and 500 on unexpected errors.
    """
    try:
        payload = request.get_json() or {}
        simulation_id = payload.get('simulation_id')

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": "请提供 simulation_id"
            }), 400

        run_state = SimulationRunner.stop_simulation(simulation_id)

        # Persist the new status for the simulation record, when it exists.
        manager = SimulationManager()
        sim_state = manager.get_simulation(simulation_id)
        if sim_state:
            sim_state.status = SimulationStatus.PAUSED
            manager._save_simulation_state(sim_state)

        response = {
            "success": True,
            "data": run_state.to_dict()
        }
        return jsonify(response)

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except Exception as e:
        logger.error(f"停止模拟失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@simulation_bp.route('/<simulation_id>/actions', methods=['GET'])
def get_simulation_actions(simulation_id: str):
    """
    Return the agent action history of a simulation.

    The URL rule must declare ``<simulation_id>``; without it the view is
    invoked with no argument and fails before its body runs.

    Query parameters:
        limit: number of actions to return (default 100)
        offset: number of actions to skip (default 0)
        platform: platform filter (twitter/reddit)
        agent_id: agent id filter
        round_num: round filter

    Response:
        {
            "success": true,
            "data": {
                "count": 100,
                "actions": [...]
            }
        }

    Returns 500 on unexpected errors.
    """
    try:
        # ``type=int`` makes Flask fall back to the default (or None) when
        # the value cannot be parsed as an integer.
        limit = request.args.get('limit', 100, type=int)
        offset = request.args.get('offset', 0, type=int)
        platform = request.args.get('platform')
        agent_id = request.args.get('agent_id', type=int)
        round_num = request.args.get('round_num', type=int)

        actions = SimulationRunner.get_actions(
            simulation_id=simulation_id,
            limit=limit,
            offset=offset,
            platform=platform,
            agent_id=agent_id,
            round_num=round_num
        )

        return jsonify({
            "success": True,
            "data": {
                "count": len(actions),
                "actions": [a.to_dict() for a in actions]
            }
        })

    except Exception as e:
        logger.error(f"获取动作历史失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@simulation_bp.route('/<simulation_id>/comments', methods=['GET'])
def get_simulation_comments(simulation_id: str):
    """
    Return comments stored by a simulation (Reddit only).

    Rows are read from the ``comment`` table of ``reddit_simulation.db`` in
    the simulation's upload directory, newest first. A missing database or a
    missing table yields an empty list instead of an error.

    The URL rule must declare ``<simulation_id>``; without it the view is
    invoked with no argument and fails before its body runs.

    Query parameters:
        post_id: restrict to one post (optional)
        limit: number of rows to return (default 50)
        offset: number of rows to skip (default 0)

    Returns 500 on unexpected errors.
    """
    try:
        post_id = request.args.get('post_id')
        limit = request.args.get('limit', 50, type=int)
        offset = request.args.get('offset', 0, type=int)

        sim_dir = os.path.join(
            os.path.dirname(__file__),
            f'../../uploads/simulations/{simulation_id}'
        )

        db_path = os.path.join(sim_dir, "reddit_simulation.db")

        if not os.path.exists(db_path):
            return jsonify({
                "success": True,
                "data": {
                    "count": 0,
                    "comments": []
                }
            })

        import sqlite3

        conn = sqlite3.connect(db_path)
        try:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            try:
                if post_id:
                    cursor.execute("""
                        SELECT * FROM comment
                        WHERE post_id = ?
                        ORDER BY created_at DESC
                        LIMIT ? OFFSET ?
                    """, (post_id, limit, offset))
                else:
                    cursor.execute("""
                        SELECT * FROM comment
                        ORDER BY created_at DESC
                        LIMIT ? OFFSET ?
                    """, (limit, offset))

                comments = [dict(row) for row in cursor.fetchall()]

            except sqlite3.OperationalError:
                # e.g. the table does not exist yet
                comments = []
        finally:
            # Always release the handle, including when an error other than
            # OperationalError propagates to the outer handler.
            conn.close()

        return jsonify({
            "success": True,
            "data": {
                "count": len(comments),
                "comments": comments
            }
        })

    except Exception as e:
        logger.error(f"获取评论失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@simulation_bp.route('/interview/batch', methods=['POST'])
def interview_agents_batch():
    """
    Interview several agents in one request.

    Note: the simulation environment must be alive (it has finished its
    simulation loop and is waiting for commands).

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",   // required, simulation ID
            "interviews": [                // required, non-empty list of objects
                {
                    "agent_id": 0,                         // required
                    "prompt": "What do you think of A?",   // required, non-empty
                    "platform": "twitter"                  // optional, per-item platform
                },
                {
                    "agent_id": 1,
                    "prompt": "What do you think of B?"    // falls back to the default platform
                }
            ],
            "platform": "reddit",   // optional default platform (overridden per item);
                                    // when omitted, dual-platform simulations interview
                                    // every agent on both platforms
            "timeout": 120          // optional, seconds, default 120
        }

    Response:
        {
            "success": true,
            "data": {
                "interviews_count": 2,
                "result": {
                    "interviews_count": 4,
                    "results": {
                        "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"},
                        "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"},
                        "twitter_1": {"agent_id": 1, "response": "...", "platform": "twitter"},
                        "reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"}
                    }
                },
                "timestamp": "2025-12-08T10:00:01"
            }
        }

    Status codes: 400 for invalid input, a dead environment or a ValueError
    raised by the runner; 504 when the runner raises TimeoutError; 500 for
    any other exception.
    """
    try:
        data = request.get_json() or {}

        simulation_id = data.get('simulation_id')
        interviews = data.get('interviews')
        platform = data.get('platform')  # optional: twitter / reddit / None
        timeout = data.get('timeout', 120)

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": "请提供 simulation_id"
            }), 400

        if not interviews or not isinstance(interviews, list):
            return jsonify({
                "success": False,
                "error": "请提供 interviews(采访列表)"
            }), 400

        # Validate the default platform.
        if platform and platform not in ("twitter", "reddit"):
            return jsonify({
                "success": False,
                "error": "platform 参数只能是 'twitter' 或 'reddit'"
            }), 400

        # Validate every interview item before touching the environment.
        for i, interview in enumerate(interviews):
            # A non-dict item (string, number, null, ...) would otherwise
            # either raise TypeError (-> HTTP 500) or pass a substring check.
            if not isinstance(interview, dict):
                return jsonify({
                    "success": False,
                    "error": f"采访列表第{i+1}项必须是JSON对象"
                }), 400

            # Same rule as the single-agent endpoint: agent_id may be 0 but
            # not missing/null.
            if interview.get('agent_id') is None:
                return jsonify({
                    "success": False,
                    "error": f"采访列表第{i+1}项缺少 agent_id"
                }), 400

            # Same rule as the single-agent endpoint: an empty prompt is
            # treated as missing.
            if not interview.get('prompt'):
                return jsonify({
                    "success": False,
                    "error": f"采访列表第{i+1}项缺少 prompt"
                }), 400

            # Validate the per-item platform when present.
            item_platform = interview.get('platform')
            if item_platform and item_platform not in ("twitter", "reddit"):
                return jsonify({
                    "success": False,
                    "error": f"采访列表第{i+1}项的platform只能是 'twitter' 或 'reddit'"
                }), 400

        # The environment must be alive to accept interview commands.
        if not SimulationRunner.check_env_alive(simulation_id):
            return jsonify({
                "success": False,
                "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。"
            }), 400

        # Rewrite every prompt through optimize_interview_prompt (adds a
        # prefix that keeps the agent from calling tools); the caller's
        # items are copied, not mutated.
        optimized_interviews = [
            {**interview, 'prompt': optimize_interview_prompt(interview['prompt'])}
            for interview in interviews
        ]

        result = SimulationRunner.interview_agents_batch(
            simulation_id=simulation_id,
            interviews=optimized_interviews,
            platform=platform,
            timeout=timeout
        )

        return jsonify({
            "success": result.get("success", False),
            "data": result
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except TimeoutError as e:
        return jsonify({
            "success": False,
            "error": f"等待批量Interview响应超时: {str(e)}"
        }), 504

    except Exception as e:
        logger.error(f"批量Interview失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
@simulation_bp.route('/interview/history', methods=['POST'])
def get_interview_history():
    """
    Return stored interview records for a simulation.

    The records are fetched through SimulationRunner.get_interview_history.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",  // required, simulation ID
            "platform": "reddit",         // optional, reddit/twitter;
                                          // omitted -> history of both platforms
            "agent_id": 0,                // optional, restrict to one agent
            "limit": 100                  // optional, max records, default 100
        }

    Response:
        {
            "success": true,
            "data": {
                "count": 10,
                "history": [
                    {
                        "agent_id": 0,
                        "response": "I think...",
                        "prompt": "What do you think about this?",
                        "timestamp": "2025-12-08T10:00:00",
                        "platform": "reddit"
                    },
                    ...
                ]
            }
        }

    Status codes: 400 for invalid input, 500 for any unexpected exception.
    """
    try:
        data = request.get_json() or {}

        simulation_id = data.get('simulation_id')
        platform = data.get('platform')  # omitted -> both platforms
        agent_id = data.get('agent_id')
        limit = data.get('limit', 100)

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": "请提供 simulation_id"
            }), 400

        # Validate platform exactly like the other interview endpoints do,
        # instead of passing an arbitrary value to the runner.
        if platform and platform not in ("twitter", "reddit"):
            return jsonify({
                "success": False,
                "error": "platform 参数只能是 'twitter' 或 'reddit'"
            }), 400

        # Coerce limit to an int (numeric strings are accepted); anything
        # else is a client error rather than an internal failure. An
        # explicit null is passed through unchanged.
        if limit is not None:
            try:
                limit = int(limit)
            except (TypeError, ValueError):
                return jsonify({
                    "success": False,
                    "error": "limit 必须是整数"
                }), 400

        history = SimulationRunner.get_interview_history(
            simulation_id=simulation_id,
            platform=platform,
            agent_id=agent_id,
            limit=limit
        )

        return jsonify({
            "success": True,
            "data": {
                "count": len(history),
                "history": history
            }
        })

    except Exception as e:
        logger.error(f"获取Interview历史失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
class Config:
    """Flask configuration; values are read from the environment once, at import time."""

    # --- Flask ---
    SECRET_KEY = os.getenv('SECRET_KEY', 'mirofish-secret-key')
    DEBUG = os.getenv('FLASK_DEBUG', 'True').lower() == 'true'

    # Emit non-ASCII characters as-is in JSON responses instead of \uXXXX escapes.
    JSON_AS_ASCII = False

    # --- LLM (OpenAI-compatible API) ---
    LLM_API_KEY = os.getenv('LLM_API_KEY')
    LLM_BASE_URL = os.getenv('LLM_BASE_URL', 'https://api.openai.com/v1')
    LLM_MODEL_NAME = os.getenv('LLM_MODEL_NAME', 'gpt-4o-mini')

    # --- Zep ---
    ZEP_API_KEY = os.getenv('ZEP_API_KEY')

    # --- File uploads ---
    MAX_CONTENT_LENGTH = 50 * 1024 * 1024  # 50MB
    UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
    ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}

    # --- Text processing ---
    DEFAULT_CHUNK_SIZE = 500     # default chunk size
    DEFAULT_CHUNK_OVERLAP = 50   # default overlap between chunks

    # --- OASIS simulation ---
    OASIS_DEFAULT_MAX_ROUNDS = int(os.getenv('OASIS_DEFAULT_MAX_ROUNDS', '10'))
    OASIS_SIMULATION_DATA_DIR = os.path.join(
        os.path.dirname(__file__), '../uploads/simulations'
    )

    # Actions available on each OASIS platform.
    OASIS_TWITTER_ACTIONS = [
        'CREATE_POST',
        'LIKE_POST',
        'REPOST',
        'FOLLOW',
        'DO_NOTHING',
        'QUOTE_POST',
    ]
    OASIS_REDDIT_ACTIONS = [
        'LIKE_POST',
        'DISLIKE_POST',
        'CREATE_POST',
        'CREATE_COMMENT',
        'LIKE_COMMENT',
        'DISLIKE_COMMENT',
        'SEARCH_POSTS',
        'SEARCH_USER',
        'TREND',
        'REFRESH',
        'DO_NOTHING',
        'FOLLOW',
        'MUTE',
    ]

    # --- Report agent ---
    REPORT_AGENT_MAX_TOOL_CALLS = int(os.getenv('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
    REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(
        os.getenv('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2')
    )
    REPORT_AGENT_TEMPERATURE = float(os.getenv('REPORT_AGENT_TEMPERATURE', '0.5'))

    @classmethod
    def validate(cls):
        """Return one error message per required setting that is unset or empty.

        Checks LLM_API_KEY and ZEP_API_KEY, in that order; an empty list
        means both are present.
        """
        required = ("LLM_API_KEY", "ZEP_API_KEY")
        return [f"{key} 未配置" for key in required if not getattr(cls, key)]
@dataclass
class Project:
    """Server-side record of a project and its processing state."""
    project_id: str
    name: str
    status: ProjectStatus
    created_at: str
    updated_at: str

    # Uploaded files
    files: List[Dict[str, str]] = field(default_factory=list)  # [{filename, path, size}]
    total_text_length: int = 0

    # Ontology (filled in once interface 1 has generated it)
    ontology: Optional[Dict[str, Any]] = None
    analysis_summary: Optional[str] = None

    # Graph (filled in once interface 2 has finished)
    graph_id: Optional[str] = None
    graph_build_task_id: Optional[str] = None

    # Settings
    simulation_requirement: Optional[str] = None
    chunk_size: int = 500
    chunk_overlap: int = 50

    # Error message, if any
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the project as a plain dict; an enum status is emitted as its string value."""
        status_value = self.status
        if isinstance(status_value, ProjectStatus):
            status_value = status_value.value

        payload: Dict[str, Any] = {
            "project_id": self.project_id,
            "name": self.name,
            "status": status_value,
        }
        # The remaining fields are copied verbatim, in declaration order.
        for key in (
            "created_at",
            "updated_at",
            "files",
            "total_text_length",
            "ontology",
            "analysis_summary",
            "graph_id",
            "graph_build_task_id",
            "simulation_requirement",
            "chunk_size",
            "chunk_overlap",
            "error",
        ):
            payload[key] = getattr(self, key)
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Project':
        """Build a Project from a dict such as the one produced by to_dict.

        Only 'project_id' is mandatory; every other key falls back to a
        default. A string status is converted to ProjectStatus.
        """
        raw_status = data.get('status', 'created')
        parsed_status = ProjectStatus(raw_status) if isinstance(raw_status, str) else raw_status

        kwargs: Dict[str, Any] = {
            'project_id': data['project_id'],
            'name': data.get('name', 'Unnamed Project'),
            'status': parsed_status,
            'created_at': data.get('created_at', ''),
            'updated_at': data.get('updated_at', ''),
            'files': data.get('files', []),
            'total_text_length': data.get('total_text_length', 0),
            'chunk_size': data.get('chunk_size', 500),
            'chunk_overlap': data.get('chunk_overlap', 50),
        }
        # Fields whose default is simply None.
        for key in (
            'ontology',
            'analysis_summary',
            'graph_id',
            'graph_build_task_id',
            'simulation_requirement',
            'error',
        ):
            kwargs[key] = data.get(key)
        return cls(**kwargs)
Project对象,如果不存在返回None """ meta_path = cls._get_project_meta_path(project_id) if not os.path.exists(meta_path): return None with open(meta_path, 'r', encoding='utf-8') as f: data = json.load(f) return Project.from_dict(data) @classmethod def list_projects(cls, limit: int = 50) -> List[Project]: """ 列出所有项目 Args: limit: 返回数量限制 Returns: 项目列表,按创建时间倒序 """ cls._ensure_projects_dir() projects = [] for project_id in os.listdir(cls.PROJECTS_DIR): project = cls.get_project(project_id) if project: projects.append(project) # 按创建时间倒序排序 projects.sort(key=lambda p: p.created_at, reverse=True) return projects[:limit] @classmethod def delete_project(cls, project_id: str) -> bool: """ 删除项目及其所有文件 Args: project_id: 项目ID Returns: 是否删除成功 """ project_dir = cls._get_project_dir(project_id) if not os.path.exists(project_dir): return False shutil.rmtree(project_dir) return True @classmethod def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]: """ 保存上传的文件到项目目录 Args: project_id: 项目ID file_storage: Flask的FileStorage对象 original_filename: 原始文件名 Returns: 文件信息字典 {filename, path, size} """ files_dir = cls._get_project_files_dir(project_id) os.makedirs(files_dir, exist_ok=True) # 生成安全的文件名 ext = os.path.splitext(original_filename)[1].lower() safe_filename = f"{uuid.uuid4().hex[:8]}{ext}" file_path = os.path.join(files_dir, safe_filename) # 保存文件 file_storage.save(file_path) # 获取文件大小 file_size = os.path.getsize(file_path) return { "original_filename": original_filename, "saved_filename": safe_filename, "path": file_path, "size": file_size } @classmethod def save_extracted_text(cls, project_id: str, text: str) -> None: """保存提取的文本""" text_path = cls._get_project_text_path(project_id) with open(text_path, 'w', encoding='utf-8') as f: f.write(text) @classmethod def get_extracted_text(cls, project_id: str) -> Optional[str]: """获取提取的文本""" text_path = cls._get_project_text_path(project_id) if not os.path.exists(text_path): return None with open(text_path, 'r', 
class TaskStatus(str, Enum):
    """Lifecycle states of a tracked task."""
    PENDING = "pending"          # waiting to start
    PROCESSING = "processing"    # in progress
    COMPLETED = "completed"      # finished successfully
    FAILED = "failed"            # finished with an error


@dataclass
class Task:
    """State of one long-running task."""
    task_id: str
    task_type: str
    status: TaskStatus
    created_at: datetime
    updated_at: datetime
    progress: int = 0              # overall progress percentage, 0-100
    message: str = ""              # human-readable status message
    result: Optional[Dict] = None  # task result
    error: Optional[str] = None    # error description
    metadata: Dict = field(default_factory=dict)          # extra metadata
    progress_detail: Dict = field(default_factory=dict)   # detailed progress info

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict (status as its value, datetimes as ISO strings)."""
        return {
            "task_id": self.task_id,
            "task_type": self.task_type,
            "status": self.status.value,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "progress": self.progress,
            "message": self.message,
            "progress_detail": self.progress_detail,
            "result": self.result,
            "error": self.error,
            "metadata": self.metadata,
        }


class TaskManager:
    """
    Process-wide, in-memory registry of tasks.

    Every call to TaskManager() returns the same instance; all access to the
    task table goes through a lock.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    # Finish initialising the object BEFORE publishing it:
                    # the unlocked check above must never observe an
                    # instance that still lacks _tasks / _task_lock.
                    instance = super().__new__(cls)
                    instance._tasks = {}  # task_id -> Task
                    instance._task_lock = threading.Lock()
                    cls._instance = instance
        return cls._instance

    def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
        """
        Register a new task in PENDING state.

        Args:
            task_type: Task type label.
            metadata: Optional extra metadata stored on the task.

        Returns:
            The generated task ID (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        now = datetime.now()

        task = Task(
            task_id=task_id,
            task_type=task_type,
            status=TaskStatus.PENDING,
            created_at=now,
            updated_at=now,
            metadata=metadata or {}
        )

        with self._task_lock:
            self._tasks[task_id] = task

        return task_id

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with this ID, or None when it is unknown."""
        with self._task_lock:
            return self._tasks.get(task_id)

    def update_task(
        self,
        task_id: str,
        status: Optional[TaskStatus] = None,
        progress: Optional[int] = None,
        message: Optional[str] = None,
        result: Optional[Dict] = None,
        error: Optional[str] = None,
        progress_detail: Optional[Dict] = None
    ):
        """
        Update the given fields of a task; arguments left as None are not touched.

        An unknown task_id is ignored. `updated_at` is refreshed on every
        successful lookup.

        Args:
            task_id: Task ID.
            status: New status.
            progress: New progress value.
            message: New status message.
            result: Task result.
            error: Error description.
            progress_detail: Detailed progress info.
        """
        changes = {
            "status": status,
            "progress": progress,
            "message": message,
            "result": result,
            "error": error,
            "progress_detail": progress_detail,
        }
        with self._task_lock:
            task = self._tasks.get(task_id)
            if task is None:
                return
            task.updated_at = datetime.now()
            for attr, value in changes.items():
                if value is not None:
                    setattr(task, attr, value)

    def complete_task(self, task_id: str, result: Dict):
        """Mark the task COMPLETED at 100% and store its result."""
        self.update_task(
            task_id,
            status=TaskStatus.COMPLETED,
            progress=100,
            message="任务完成",
            result=result
        )

    def fail_task(self, task_id: str, error: str):
        """Mark the task FAILED and store the error text."""
        self.update_task(
            task_id,
            status=TaskStatus.FAILED,
            message="任务失败",
            error=error
        )

    def list_tasks(self, task_type: Optional[str] = None) -> list:
        """Return task dicts, newest first, optionally filtered by task_type."""
        # Serialise while holding the lock so a concurrent update_task
        # cannot change a task halfway through its to_dict().
        with self._task_lock:
            selected = [
                task for task in self._tasks.values()
                if not task_type or task.task_type == task_type
            ]
            selected.sort(key=lambda task: task.created_at, reverse=True)
            return [task.to_dict() for task in selected]

    def cleanup_old_tasks(self, max_age_hours: int = 24):
        """
        Drop finished (COMPLETED/FAILED) tasks created more than max_age_hours ago.

        Returns:
            The number of tasks removed.
        """
        from datetime import timedelta

        cutoff = datetime.now() - timedelta(hours=max_age_hours)
        finished = (TaskStatus.COMPLETED, TaskStatus.FAILED)

        with self._task_lock:
            stale_ids = [
                tid for tid, task in self._tasks.items()
                if task.created_at < cutoff and task.status in finished
            ]
            for tid in stale_ids:
                del self._tasks[tid]
        return len(stale_ids)
import uuid import time import threading from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass from zep_cloud.client import Zep from zep_cloud import EpisodeData, EntityEdgeSourceTarget from ..config import Config from ..models.task import TaskManager, TaskStatus from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges from .text_processor import TextProcessor @dataclass class GraphInfo: """图谱信息""" graph_id: str node_count: int edge_count: int entity_types: List[str] def to_dict(self) -> Dict[str, Any]: return { "graph_id": self.graph_id, "node_count": self.node_count, "edge_count": self.edge_count, "entity_types": self.entity_types, } class GraphBuilderService: """ 图谱构建服务 负责调用Zep API构建知识图谱 """ def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") self.client = Zep(api_key=self.api_key) self.task_manager = TaskManager() def build_graph_async( self, text: str, ontology: Dict[str, Any], graph_name: str = "MiroFish Graph", chunk_size: int = 500, chunk_overlap: int = 50, batch_size: int = 3 ) -> str: """ 异步构建图谱 Args: text: 输入文本 ontology: 本体定义(来自接口1的输出) graph_name: 图谱名称 chunk_size: 文本块大小 chunk_overlap: 块重叠大小 batch_size: 每批发送的块数量 Returns: 任务ID """ # 创建任务 task_id = self.task_manager.create_task( task_type="graph_build", metadata={ "graph_name": graph_name, "chunk_size": chunk_size, "text_length": len(text), } ) # 在后台线程中执行构建 thread = threading.Thread( target=self._build_graph_worker, args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size) ) thread.daemon = True thread.start() return task_id def _build_graph_worker( self, task_id: str, text: str, ontology: Dict[str, Any], graph_name: str, chunk_size: int, chunk_overlap: int, batch_size: int ): """图谱构建工作线程""" try: self.task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=5, message="开始构建图谱..." ) # 1. 
创建图谱 graph_id = self.create_graph(graph_name) self.task_manager.update_task( task_id, progress=10, message=f"图谱已创建: {graph_id}" ) # 2. 设置本体 self.set_ontology(graph_id, ontology) self.task_manager.update_task( task_id, progress=15, message="本体已设置" ) # 3. 文本分块 chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) total_chunks = len(chunks) self.task_manager.update_task( task_id, progress=20, message=f"文本已分割为 {total_chunks} 个块" ) # 4. 分批发送数据 episode_uuids = self.add_text_batches( graph_id, chunks, batch_size, lambda msg, prog: self.task_manager.update_task( task_id, progress=20 + int(prog * 0.4), # 20-60% message=msg ) ) # 5. 等待Zep处理完成 self.task_manager.update_task( task_id, progress=60, message="等待Zep处理数据..." ) self._wait_for_episodes( episode_uuids, lambda msg, prog: self.task_manager.update_task( task_id, progress=60 + int(prog * 0.3), # 60-90% message=msg ) ) # 6. 获取图谱信息 self.task_manager.update_task( task_id, progress=90, message="获取图谱信息..." ) graph_info = self._get_graph_info(graph_id) # 完成 self.task_manager.complete_task(task_id, { "graph_id": graph_id, "graph_info": graph_info.to_dict(), "chunks_processed": total_chunks, }) except Exception as e: import traceback error_msg = f"{str(e)}\n{traceback.format_exc()}" self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: """创建Zep图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" self.client.graph.create( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" ) return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): """设置图谱本体(公开方法)""" import warnings from typing import Optional from pydantic import Field from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel # 抑制 Pydantic v2 关于 Field(default=None) 的警告 # 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略 warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') # Zep 保留名称,不能作为属性名 RESERVED_NAMES = {'uuid', 'name', 'group_id', 
'name_embedding', 'summary', 'created_at'} def safe_attr_name(attr_name: str) -> str: """将保留名称转换为安全名称""" if attr_name.lower() in RESERVED_NAMES: return f"entity_{attr_name}" return attr_name # 动态创建实体类型 entity_types = {} for entity_def in ontology.get("entity_types", []): name = entity_def["name"] description = entity_def.get("description", f"A {name} entity.") # 创建属性字典和类型注解(Pydantic v2 需要) attrs = {"__doc__": description} annotations = {} for attr_def in entity_def.get("attributes", []): attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 attr_desc = attr_def.get("description", attr_name) # Zep API 需要 Field 的 description,这是必需的 attrs[attr_name] = Field(description=attr_desc, default=None) annotations[attr_name] = Optional[EntityText] # 类型注解 attrs["__annotations__"] = annotations # 动态创建类 entity_class = type(name, (EntityModel,), attrs) entity_class.__doc__ = description entity_types[name] = entity_class # 动态创建边类型 edge_definitions = {} for edge_def in ontology.get("edge_types", []): name = edge_def["name"] description = edge_def.get("description", f"A {name} relationship.") # 创建属性字典和类型注解 attrs = {"__doc__": description} annotations = {} for attr_def in edge_def.get("attributes", []): attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 attr_desc = attr_def.get("description", attr_name) # Zep API 需要 Field 的 description,这是必需的 attrs[attr_name] = Field(description=attr_desc, default=None) annotations[attr_name] = Optional[str] # 边属性用str类型 attrs["__annotations__"] = annotations # 动态创建类 class_name = ''.join(word.capitalize() for word in name.split('_')) edge_class = type(class_name, (EdgeModel,), attrs) edge_class.__doc__ = description # 构建source_targets source_targets = [] for st in edge_def.get("source_targets", []): source_targets.append( EntityEdgeSourceTarget( source=st.get("source", "Entity"), target=st.get("target", "Entity") ) ) if source_targets: edge_definitions[name] = (edge_class, source_targets) # 调用Zep API设置本体 if entity_types or edge_definitions: 
self.client.graph.set_ontology( graph_ids=[graph_id], entities=entity_types if entity_types else None, edges=edge_definitions if edge_definitions else None, ) def add_text_batches( self, graph_id: str, chunks: List[str], batch_size: int = 3, progress_callback: Optional[Callable] = None ) -> List[str]: """分批添加文本到图谱,返回所有 episode 的 uuid 列表""" episode_uuids = [] total_chunks = len(chunks) for i in range(0, total_chunks, batch_size): batch_chunks = chunks[i:i + batch_size] batch_num = i // batch_size + 1 total_batches = (total_chunks + batch_size - 1) // batch_size if progress_callback: progress = (i + len(batch_chunks)) / total_chunks progress_callback( f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...", progress ) # 构建episode数据 episodes = [ EpisodeData(data=chunk, type="text") for chunk in batch_chunks ] # 发送到Zep try: batch_result = self.client.graph.add_batch( graph_id=graph_id, episodes=episodes ) # 收集返回的 episode uuid if batch_result and isinstance(batch_result, list): for ep in batch_result: ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) if ep_uuid: episode_uuids.append(ep_uuid) # 避免请求过快 time.sleep(1) except Exception as e: if progress_callback: progress_callback(f"批次 {batch_num} 发送失败: {str(e)}", 0) raise return episode_uuids def _wait_for_episodes( self, episode_uuids: List[str], progress_callback: Optional[Callable] = None, timeout: int = 600 ): """等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" if not episode_uuids: if progress_callback: progress_callback("无需等待(没有 episode)", 1.0) return start_time = time.time() pending_episodes = set(episode_uuids) completed_count = 0 total_episodes = len(episode_uuids) if progress_callback: progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0) while pending_episodes: if time.time() - start_time > timeout: if progress_callback: progress_callback( f"部分文本块超时,已完成 {completed_count}/{total_episodes}", completed_count / total_episodes ) break # 检查每个 episode 的处理状态 for ep_uuid in 
list(pending_episodes): try: episode = self.client.graph.episode.get(uuid_=ep_uuid) is_processed = getattr(episode, 'processed', False) if is_processed: pending_episodes.remove(ep_uuid) completed_count += 1 except Exception as e: # 忽略单个查询错误,继续 pass elapsed = int(time.time() - start_time) if progress_callback: progress_callback( f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)", completed_count / total_episodes if total_episodes > 0 else 0 ) if pending_episodes: time.sleep(3) # 每3秒检查一次 if progress_callback: progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0) def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" # 获取节点(分页) nodes = fetch_all_nodes(self.client, graph_id) # 获取边(分页) edges = fetch_all_edges(self.client, graph_id) # 统计实体类型 entity_types = set() for node in nodes: if node.labels: for label in node.labels: if label not in ["Entity", "Node"]: entity_types.add(label) return GraphInfo( graph_id=graph_id, node_count=len(nodes), edge_count=len(edges), entity_types=list(entity_types) ) def get_graph_data(self, graph_id: str) -> Dict[str, Any]: """ 获取完整图谱数据(包含详细信息) Args: graph_id: 图谱ID Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ nodes = fetch_all_nodes(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id) # 创建节点映射用于获取节点名称 node_map = {} for node in nodes: node_map[node.uuid_] = node.name or "" nodes_data = [] for node in nodes: # 获取创建时间 created_at = getattr(node, 'created_at', None) if created_at: created_at = str(created_at) nodes_data.append({ "uuid": node.uuid_, "name": node.name, "labels": node.labels or [], "summary": node.summary or "", "attributes": node.attributes or {}, "created_at": created_at, }) edges_data = [] for edge in edges: # 获取时间信息 created_at = getattr(edge, 'created_at', None) valid_at = getattr(edge, 'valid_at', None) invalid_at = getattr(edge, 'invalid_at', None) expired_at = getattr(edge, 'expired_at', None) # 获取 episodes episodes = getattr(edge, 
def delete_graph(self, graph_id: str):
    """Delete the graph identified by ``graph_id`` through the client.

    Args:
        graph_id: Identifier of the graph to remove.
    """
    graph_api = self.client.graph
    graph_api.delete(graph_id=graph_id)
@dataclass
class OasisAgentProfile:
    """Agent profile record with converters for the Reddit and Twitter layouts."""

    # Fields shared by every platform
    user_id: int
    user_name: str
    name: str
    bio: str
    persona: str

    # Optional field, Reddit style
    karma: int = 1000

    # Optional fields, Twitter style
    friend_count: int = 100
    follower_count: int = 150
    statuses_count: int = 500

    # Additional persona details
    age: Optional[int] = None
    gender: Optional[str] = None
    mbti: Optional[str] = None
    country: Optional[str] = None
    profession: Optional[str] = None
    interested_topics: List[str] = field(default_factory=list)

    # Information about the source entity
    source_entity_uuid: Optional[str] = None
    source_entity_type: Optional[str] = None

    created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))

    def _with_persona_extras(self, profile: Dict[str, Any]) -> Dict[str, Any]:
        """Append each truthy extra persona field to ``profile``, in fixed order."""
        for key in ("age", "gender", "mbti", "country", "profession", "interested_topics"):
            value = getattr(self, key)
            if value:
                profile[key] = value
        return profile

    def to_reddit_format(self) -> Dict[str, Any]:
        """Build the Reddit-style dictionary for this profile."""
        return self._with_persona_extras({
            "user_id": self.user_id,
            # The OASIS library expects the key "username" (no underscore).
            "username": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "karma": self.karma,
            "created_at": self.created_at,
        })

    def to_twitter_format(self) -> Dict[str, Any]:
        """Build the Twitter-style dictionary for this profile."""
        return self._with_persona_extras({
            "user_id": self.user_id,
            # The OASIS library expects the key "username" (no underscore).
            "username": self.user_name,
            "name": self.name,
            "bio": self.bio,
            "persona": self.persona,
            "friend_count": self.friend_count,
            "follower_count": self.follower_count,
            "statuses_count": self.statuses_count,
            "created_at": self.created_at,
        })

    def to_dict(self) -> Dict[str, Any]:
        """Return every field of the profile as a plain dictionary."""
        keys = (
            "user_id", "user_name", "name", "bio", "persona",
            "karma", "friend_count", "follower_count", "statuses_count",
            "age", "gender", "mbti", "country", "profession",
            "interested_topics", "source_entity_uuid", "source_entity_type",
            "created_at",
        )
        return {key: getattr(self, key) for key in keys}
def __init__(
    self,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    model_name: Optional[str] = None,
    zep_api_key: Optional[str] = None,
    graph_id: Optional[str] = None
):
    """Set up the LLM client and, when a key is available, the Zep client.

    Each explicit argument takes precedence; a missing or empty value falls
    back to the matching ``Config`` attribute.

    Args:
        api_key: LLM API key.
        base_url: LLM endpoint base URL.
        model_name: LLM model name.
        zep_api_key: Zep API key.
        graph_id: Graph identifier stored for later Zep searches.

    Raises:
        ValueError: If no LLM API key is available from either source.
    """
    self.api_key = api_key if api_key else Config.LLM_API_KEY
    self.base_url = base_url if base_url else Config.LLM_BASE_URL
    self.model_name = model_name if model_name else Config.LLM_MODEL_NAME

    if not self.api_key:
        raise ValueError("LLM_API_KEY 未配置")

    self.client = OpenAI(
        api_key=self.api_key,
        base_url=self.base_url
    )

    # The Zep client is optional and only used to enrich entity context.
    self.zep_api_key = zep_api_key if zep_api_key else Config.ZEP_API_KEY
    self.zep_client = None
    self.graph_id = graph_id

    if not self.zep_api_key:
        return

    try:
        self.zep_client = Zep(api_key=self.zep_api_key)
    except Exception as exc:
        # Construction failure is logged and the client stays ``None``.
        logger.warning(f"Zep客户端初始化失败: {exc}")
profile_data = self._generate_profile_rule_based( entity_name=name, entity_type=entity_type, entity_summary=entity.summary, entity_attributes=entity.attributes ) return OasisAgentProfile( user_id=user_id, user_name=user_name, name=name, bio=profile_data.get("bio", f"{entity_type}: {name}"), persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."), karma=profile_data.get("karma", random.randint(500, 5000)), friend_count=profile_data.get("friend_count", random.randint(50, 500)), follower_count=profile_data.get("follower_count", random.randint(100, 1000)), statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)), age=profile_data.get("age"), gender=profile_data.get("gender"), mbti=profile_data.get("mbti"), country=profile_data.get("country"), profession=profile_data.get("profession"), interested_topics=profile_data.get("interested_topics", []), source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) def _generate_username(self, name: str) -> str: """生成用户名""" # 移除特殊字符,转换为小写 username = name.lower().replace(" ", "_") username = ''.join(c for c in username if c.isalnum() or c == '_') # 添加随机后缀避免重复 suffix = random.randint(100, 999) return f"{username}_{suffix}" def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: """ 使用Zep图谱混合搜索功能获取实体相关的丰富信息 Zep没有内置混合搜索接口,需要分别搜索edges和nodes然后合并结果。 使用并行请求同时搜索,提高效率。 Args: entity: 实体节点对象 Returns: 包含facts, node_summaries, context的字典 """ import concurrent.futures if not self.zep_client: return {"facts": [], "node_summaries": [], "context": ""} entity_name = entity.name results = { "facts": [], "node_summaries": [], "context": "" } # 必须有graph_id才能进行搜索 if not self.graph_id: logger.debug(f"跳过Zep检索:未设置graph_id") return results comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景" def search_edges(): """搜索边(事实/关系)- 带重试机制""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: return self.zep_client.graph.search( 
query=comprehensive_query, graph_id=self.graph_id, limit=30, scope="edges", reranker="rrf" ) except Exception as e: last_exception = e if attempt < max_retries - 1: logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") time.sleep(delay) delay *= 2 else: logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}") return None def search_nodes(): """搜索节点(实体摘要)- 带重试机制""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: return self.zep_client.graph.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, scope="nodes", reranker="rrf" ) except Exception as e: last_exception = e if attempt < max_retries - 1: logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") time.sleep(delay) delay *= 2 else: logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}") return None try: # 并行执行edges和nodes搜索 with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: edge_future = executor.submit(search_edges) node_future = executor.submit(search_nodes) # 获取结果 edge_result = edge_future.result(timeout=30) node_result = node_future.result(timeout=30) # 处理边搜索结果 all_facts = set() if edge_result and hasattr(edge_result, 'edges') and edge_result.edges: for edge in edge_result.edges: if hasattr(edge, 'fact') and edge.fact: all_facts.add(edge.fact) results["facts"] = list(all_facts) # 处理节点搜索结果 all_summaries = set() if node_result and hasattr(node_result, 'nodes') and node_result.nodes: for node in node_result.nodes: if hasattr(node, 'summary') and node.summary: all_summaries.add(node.summary) if hasattr(node, 'name') and node.name and node.name != entity_name: all_summaries.add(f"相关实体: {node.name}") results["node_summaries"] = list(all_summaries) # 构建综合上下文 context_parts = [] if results["facts"]: context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20])) if results["node_summaries"]: context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10])) results["context"] = 
"\n\n".join(context_parts) logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点") except concurrent.futures.TimeoutError: logger.warning(f"Zep检索超时 ({entity_name})") except Exception as e: logger.warning(f"Zep检索失败 ({entity_name}): {e}") return results def _build_entity_context(self, entity: EntityNode) -> str: """ 构建实体的完整上下文信息 包括: 1. 实体本身的边信息(事实) 2. 关联节点的详细信息 3. Zep混合检索到的丰富信息 """ context_parts = [] # 1. 添加实体属性信息 if entity.attributes: attrs = [] for key, value in entity.attributes.items(): if value and str(value).strip(): attrs.append(f"- {key}: {value}") if attrs: context_parts.append("### 实体属性\n" + "\n".join(attrs)) # 2. 添加相关边信息(事实/关系) existing_facts = set() if entity.related_edges: relationships = [] for edge in entity.related_edges: # 不限制数量 fact = edge.get("fact", "") edge_name = edge.get("edge_name", "") direction = edge.get("direction", "") if fact: relationships.append(f"- {fact}") existing_facts.add(fact) elif edge_name: if direction == "outgoing": relationships.append(f"- {entity.name} --[{edge_name}]--> (相关实体)") else: relationships.append(f"- (相关实体) --[{edge_name}]--> {entity.name}") if relationships: context_parts.append("### 相关事实和关系\n" + "\n".join(relationships)) # 3. 添加关联节点的详细信息 if entity.related_nodes: related_info = [] for node in entity.related_nodes: # 不限制数量 node_name = node.get("name", "") node_labels = node.get("labels", []) node_summary = node.get("summary", "") # 过滤掉默认标签 custom_labels = [l for l in node_labels if l not in ["Entity", "Node"]] label_str = f" ({', '.join(custom_labels)})" if custom_labels else "" if node_summary: related_info.append(f"- **{node_name}**{label_str}: {node_summary}") else: related_info.append(f"- **{node_name}**{label_str}") if related_info: context_parts.append("### 关联实体信息\n" + "\n".join(related_info)) # 4. 
使用Zep混合检索获取更丰富的信息 zep_results = self._search_zep_for_entity(entity) if zep_results.get("facts"): # 去重:排除已存在的事实 new_facts = [f for f in zep_results["facts"] if f not in existing_facts] if new_facts: context_parts.append("### Zep检索到的事实信息\n" + "\n".join(f"- {f}" for f in new_facts[:15])) if zep_results.get("node_summaries"): context_parts.append("### Zep检索到的相关节点\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10])) return "\n\n".join(context_parts) def _is_individual_entity(self, entity_type: str) -> bool: """判断是否是个人类型实体""" return entity_type.lower() in self.INDIVIDUAL_ENTITY_TYPES def _is_group_entity(self, entity_type: str) -> bool: """判断是否是群体/机构类型实体""" return entity_type.lower() in self.GROUP_ENTITY_TYPES def _generate_profile_with_llm( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> Dict[str, Any]: """ 使用LLM生成非常详细的人设 根据实体类型区分: - 个人实体:生成具体的人物设定 - 群体/机构实体:生成代表性账号设定 """ is_individual = self._is_individual_entity(entity_type) if is_individual: prompt = self._build_individual_persona_prompt( entity_name, entity_type, entity_summary, entity_attributes, context ) else: prompt = self._build_group_persona_prompt( entity_name, entity_type, entity_summary, entity_attributes, context ) # 尝试多次生成,直到成功或达到最大重试次数 max_attempts = 3 last_error = None for attempt in range(max_attempts): try: response = self.client.chat.completions.create( model=self.model_name, messages=[ {"role": "system", "content": self._get_system_prompt(is_individual)}, {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 # 不设置max_tokens,让LLM自由发挥 ) content = response.choices[0].message.content # 检查是否被截断(finish_reason不是'stop') finish_reason = response.choices[0].finish_reason if finish_reason == 'length': logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...") content = self._fix_truncated_json(content) # 尝试解析JSON try: result = json.loads(content) # 验证必需字段 
if "bio" not in result or not result["bio"]: result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" if "persona" not in result or not result["persona"]: result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" return result except json.JSONDecodeError as je: logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}") # 尝试修复JSON result = self._try_fix_json(content, entity_name, entity_type, entity_summary) if result.get("_fixed"): del result["_fixed"] return result last_error = je except Exception as e: logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") last_error = e import time time.sleep(1 * (attempt + 1)) # 指数退避 logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成") return self._generate_profile_rule_based( entity_name, entity_type, entity_summary, entity_attributes ) def _fix_truncated_json(self, content: str) -> str: """修复被截断的JSON(输出被max_tokens限制截断)""" import re # 如果JSON被截断,尝试闭合它 content = content.strip() # 计算未闭合的括号 open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') # 检查是否有未闭合的字符串 # 简单检查:如果最后一个引号后没有逗号或闭合括号,可能是字符串被截断 if content and content[-1] not in '",}]': # 尝试闭合字符串 content += '"' # 闭合括号 content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_json(self, content: str, entity_name: str, entity_type: str, entity_summary: str = "") -> Dict[str, Any]: """尝试修复损坏的JSON""" import re # 1. 首先尝试修复被截断的情况 content = self._fix_truncated_json(content) # 2. 尝试提取JSON部分 json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() # 3. 处理字符串中的换行符问题 # 找到所有字符串值并替换其中的换行符 def fix_string_newlines(match): s = match.group(0) # 替换字符串内的实际换行符为空格 s = s.replace('\n', ' ').replace('\r', ' ') # 替换多余空格 s = re.sub(r'\s+', ' ', s) return s # 匹配JSON字符串值 json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str) # 4. 
尝试解析 try: result = json.loads(json_str) result["_fixed"] = True return result except json.JSONDecodeError as e: # 5. 如果还是失败,尝试更激进的修复 try: # 移除所有控制字符 json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) # 替换所有连续空白 json_str = re.sub(r'\s+', ' ', json_str) result = json.loads(json_str) result["_fixed"] = True return result except: pass # 6. 尝试从内容中提取部分信息 bio_match = re.search(r'"bio"\s*:\s*"([^"]*)"', content) persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) # 可能被截断 bio = bio_match.group(1) if bio_match else (entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}") persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name}是一个{entity_type}。") # 如果提取到了有意义的内容,标记为已修复 if bio_match or persona_match: logger.info(f"从损坏的JSON中提取了部分信息") return { "bio": bio, "persona": persona, "_fixed": True } # 7. 完全失败,返回基础结构 logger.warning(f"JSON修复失败,返回基础结构") return { "bio": entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name}是一个{entity_type}。" } def _get_system_prompt(self, is_individual: bool) -> str: """获取系统提示词""" base_prompt = "你是社交媒体用户画像生成专家。生成详细、真实的人设用于舆论模拟,最大程度还原已有现实情况。必须返回有效的JSON格式,所有字符串值不能包含未转义的换行符。使用中文。" return base_prompt def _build_individual_persona_prompt( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> str: """构建个人实体的详细人设提示词""" attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "无" context_str = context[:3000] if context else "无额外上下文" return f"""为实体生成详细的社交媒体用户人设,最大程度还原已有现实情况。 实体名称: {entity_name} 实体类型: {entity_type} 实体摘要: {entity_summary} 实体属性: {attrs_str} 上下文信息: {context_str} 请生成JSON,包含以下字段: 1. bio: 社交媒体简介,200字 2. 
persona: 详细人设描述(2000字的纯文本),需包含: - 基本信息(年龄、职业、教育背景、所在地) - 人物背景(重要经历、与事件的关联、社会关系) - 性格特征(MBTI类型、核心性格、情绪表达方式) - 社交媒体行为(发帖频率、内容偏好、互动风格、语言特点) - 立场观点(对话题的态度、可能被激怒/感动的内容) - 独特特征(口头禅、特殊经历、个人爱好) - 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应) 3. age: 年龄数字(必须是整数) 4. gender: 性别,必须是英文: "male" 或 "female" 5. mbti: MBTI类型(如INTJ、ENFP等) 6. country: 国家(使用中文,如"中国") 7. profession: 职业 8. interested_topics: 感兴趣话题数组 重要: - 所有字段值必须是字符串或数字,不要使用换行符 - persona必须是一段连贯的文字描述 - 使用中文(除了gender字段必须用英文male/female) - 内容要与实体信息保持一致 - age必须是有效的整数,gender必须是"male"或"female" """ def _build_group_persona_prompt( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> str: """构建群体/机构实体的详细人设提示词""" attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "无" context_str = context[:3000] if context else "无额外上下文" return f"""为机构/群体实体生成详细的社交媒体账号设定,最大程度还原已有现实情况。 实体名称: {entity_name} 实体类型: {entity_type} 实体摘要: {entity_summary} 实体属性: {attrs_str} 上下文信息: {context_str} 请生成JSON,包含以下字段: 1. bio: 官方账号简介,200字,专业得体 2. persona: 详细账号设定描述(2000字的纯文本),需包含: - 机构基本信息(正式名称、机构性质、成立背景、主要职能) - 账号定位(账号类型、目标受众、核心功能) - 发言风格(语言特点、常用表达、禁忌话题) - 发布内容特点(内容类型、发布频率、活跃时间段) - 立场态度(对核心话题的官方立场、面对争议的处理方式) - 特殊说明(代表的群体画像、运营习惯) - 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应) 3. age: 固定填30(机构账号的虚拟年龄) 4. gender: 固定填"other"(机构账号使用other表示非个人) 5. mbti: MBTI类型,用于描述账号风格,如ISTJ代表严谨保守 6. country: 国家(使用中文,如"中国") 7. profession: 机构职能描述 8. 
interested_topics: 关注领域数组 重要: - 所有字段值必须是字符串或数字,不允许null值 - persona必须是一段连贯的文字描述,不要使用换行符 - 使用中文(除了gender字段必须用英文"other") - age必须是整数30,gender必须是字符串"other" - 机构账号发言要符合其身份定位""" def _generate_profile_rule_based( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any] ) -> Dict[str, Any]: """使用规则生成基础人设""" # 根据实体类型生成不同的人设 entity_type_lower = entity_type.lower() if entity_type_lower in ["student", "alumni"]: return { "bio": f"{entity_type} with interests in academics and social issues.", "persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.", "age": random.randint(18, 30), "gender": random.choice(["male", "female"]), "mbti": random.choice(self.MBTI_TYPES), "country": random.choice(self.COUNTRIES), "profession": "Student", "interested_topics": ["Education", "Social Issues", "Technology"], } elif entity_type_lower in ["publicfigure", "expert", "faculty"]: return { "bio": f"Expert and thought leader in their field.", "persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.", "age": random.randint(35, 60), "gender": random.choice(["male", "female"]), "mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]), "country": random.choice(self.COUNTRIES), "profession": entity_attributes.get("occupation", "Expert"), "interested_topics": ["Politics", "Economics", "Culture & Society"], } elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]: return { "bio": f"Official account for {entity_name}. News and updates.", "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. 
The account shares timely updates and engages with the audience on current events.", "age": 30, # 机构虚拟年龄 "gender": "other", # 机构使用other "mbti": "ISTJ", # 机构风格:严谨保守 "country": "中国", "profession": "Media", "interested_topics": ["General News", "Current Events", "Public Affairs"], } elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]: return { "bio": f"Official account of {entity_name}.", "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", "age": 30, # 机构虚拟年龄 "gender": "other", # 机构使用other "mbti": "ISTJ", # 机构风格:严谨保守 "country": "中国", "profession": entity_type, "interested_topics": ["Public Policy", "Community", "Official Announcements"], } else: # 默认人设 return { "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.", "age": random.randint(25, 50), "gender": random.choice(["male", "female"]), "mbti": random.choice(self.MBTI_TYPES), "country": random.choice(self.COUNTRIES), "profession": entity_type, "interested_topics": ["General", "Social Issues"], } def set_graph_id(self, graph_id: str): """设置图谱ID用于Zep检索""" self.graph_id = graph_id def generate_profiles_from_entities( self, entities: List[EntityNode], use_llm: bool = True, progress_callback: Optional[callable] = None, graph_id: Optional[str] = None, parallel_count: int = 5, realtime_output_path: Optional[str] = None, output_platform: str = "reddit" ) -> List[OasisAgentProfile]: """ 批量从实体生成Agent Profile(支持并行生成) Args: entities: 实体列表 use_llm: 是否使用LLM生成详细人设 progress_callback: 进度回调函数 (current, total, message) graph_id: 图谱ID,用于Zep检索获取更丰富上下文 parallel_count: 并行生成数量,默认5 realtime_output_path: 实时写入的文件路径(如果提供,每生成一个就写入一次) output_platform: 输出平台格式 ("reddit" 或 "twitter") Returns: Agent Profile列表 """ import concurrent.futures from threading import Lock # 
设置graph_id用于Zep检索 if graph_id: self.graph_id = graph_id total = len(entities) profiles = [None] * total # 预分配列表保持顺序 completed_count = [0] # 使用列表以便在闭包中修改 lock = Lock() # 实时写入文件的辅助函数 def save_profiles_realtime(): """实时保存已生成的 profiles 到文件""" if not realtime_output_path: return with lock: # 过滤出已生成的 profiles existing_profiles = [p for p in profiles if p is not None] if not existing_profiles: return try: if output_platform == "reddit": # Reddit JSON 格式 profiles_data = [p.to_reddit_format() for p in existing_profiles] with open(realtime_output_path, 'w', encoding='utf-8') as f: json.dump(profiles_data, f, ensure_ascii=False, indent=2) else: # Twitter CSV 格式 import csv profiles_data = [p.to_twitter_format() for p in existing_profiles] if profiles_data: fieldnames = list(profiles_data[0].keys()) with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(profiles_data) except Exception as e: logger.warning(f"实时保存 profiles 失败: {e}") def generate_single_profile(idx: int, entity: EntityNode) -> tuple: """生成单个profile的工作函数""" entity_type = entity.get_entity_type() or "Entity" try: profile = self.generate_profile_from_entity( entity=entity, user_id=idx, use_llm=use_llm ) # 实时输出生成的人设到控制台和日志 self._print_generated_profile(entity.name, entity_type, profile) return idx, profile, None except Exception as e: logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}") # 创建一个基础profile fallback_profile = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), name=entity.name, bio=f"{entity_type}: {entity.name}", persona=entity.summary or f"A participant in social discussions.", source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) return idx, fallback_profile, str(e) logger.info(f"开始并行生成 {total} 个Agent人设(并行数: {parallel_count})...") print(f"\n{'='*60}") print(f"开始生成Agent人设 - 共 {total} 个实体,并行数: {parallel_count}") print(f"{'='*60}\n") # 使用线程池并行执行 with 
concurrent.futures.ThreadPoolExecutor(max_workers=parallel_count) as executor: # 提交所有任务 future_to_entity = { executor.submit(generate_single_profile, idx, entity): (idx, entity) for idx, entity in enumerate(entities) } # 收集结果 for future in concurrent.futures.as_completed(future_to_entity): idx, entity = future_to_entity[future] entity_type = entity.get_entity_type() or "Entity" try: result_idx, profile, error = future.result() profiles[result_idx] = profile with lock: completed_count[0] += 1 current = completed_count[0] # 实时写入文件 save_profiles_realtime() if progress_callback: progress_callback( current, total, f"已完成 {current}/{total}: {entity.name}({entity_type})" ) if error: logger.warning(f"[{current}/{total}] {entity.name} 使用备用人设: {error}") else: logger.info(f"[{current}/{total}] 成功生成人设: {entity.name} ({entity_type})") except Exception as e: logger.error(f"处理实体 {entity.name} 时发生异常: {str(e)}") with lock: completed_count[0] += 1 profiles[idx] = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), name=entity.name, bio=f"{entity_type}: {entity.name}", persona=entity.summary or "A participant in social discussions.", source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) # 实时写入文件(即使是备用人设) save_profiles_realtime() print(f"\n{'='*60}") print(f"人设生成完成!共生成 {len([p for p in profiles if p])} 个Agent") print(f"{'='*60}\n") return profiles def _print_generated_profile(self, entity_name: str, entity_type: str, profile: OasisAgentProfile): """实时输出生成的人设到控制台(完整内容,不截断)""" separator = "-" * 70 # 构建完整输出内容(不截断) topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else '无' output_lines = [ f"\n{separator}", f"[已生成] {entity_name} ({entity_type})", f"{separator}", f"用户名: {profile.user_name}", f"", f"【简介】", f"{profile.bio}", f"", f"【详细人设】", f"{profile.persona}", f"", f"【基本属性】", f"年龄: {profile.age} | 性别: {profile.gender} | MBTI: {profile.mbti}", f"职业: {profile.profession} | 国家: {profile.country}", f"兴趣话题: {topics_str}", 
separator ] output = "\n".join(output_lines) # 只输出到控制台(避免重复,logger不再输出完整内容) print(output) def save_profiles( self, profiles: List[OasisAgentProfile], file_path: str, platform: str = "reddit" ): """ 保存Profile到文件(根据平台选择正确格式) OASIS平台格式要求: - Twitter: CSV格式 - Reddit: JSON格式 Args: profiles: Profile列表 file_path: 文件路径 platform: 平台类型 ("reddit" 或 "twitter") """ if platform == "twitter": self._save_twitter_csv(profiles, file_path) else: self._save_reddit_json(profiles, file_path) def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str): """ 保存Twitter Profile为CSV格式(符合OASIS官方要求) OASIS Twitter要求的CSV字段: - user_id: 用户ID(根据CSV顺序从0开始) - name: 用户真实姓名 - username: 系统中的用户名 - user_char: 详细人设描述(注入到LLM系统提示中,指导Agent行为) - description: 简短的公开简介(显示在用户资料页面) user_char vs description 区别: - user_char: 内部使用,LLM系统提示,决定Agent如何思考和行动 - description: 外部显示,其他用户可见的简介 """ import csv # 确保文件扩展名是.csv if not file_path.endswith('.csv'): file_path = file_path.replace('.json', '.csv') with open(file_path, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) # 写入OASIS要求的表头 headers = ['user_id', 'name', 'username', 'user_char', 'description'] writer.writerow(headers) # 写入数据行 for idx, profile in enumerate(profiles): # user_char: 完整人设(bio + persona),用于LLM系统提示 user_char = profile.bio if profile.persona and profile.persona != profile.bio: user_char = f"{profile.bio} {profile.persona}" # 处理换行符(CSV中用空格替代) user_char = user_char.replace('\n', ' ').replace('\r', ' ') # description: 简短简介,用于外部显示 description = profile.bio.replace('\n', ' ').replace('\r', ' ') row = [ idx, # user_id: 从0开始的顺序ID profile.name, # name: 真实姓名 profile.user_name, # username: 用户名 user_char, # user_char: 完整人设(内部LLM使用) description # description: 简短简介(外部显示) ] writer.writerow(row) logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)") def _normalize_gender(self, gender: Optional[str]) -> str: """ 标准化gender字段为OASIS要求的英文格式 OASIS要求: male, female, other """ if not gender: return "other" gender_lower = 
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
    """Write Reddit profiles to ``file_path`` as a JSON array.

    Every record carries ``user_id``, ``username``, ``name``, ``bio``,
    ``persona``, ``karma``, ``created_at``, ``age``, ``gender``, ``mbti``
    and ``country``. ``profession`` and ``interested_topics`` are added only
    when set. Missing values fall back to defaults: the list position for
    ``user_id``, 1000 for karma, 30 for age, "ISTJ" for mbti and "中国" for
    country. The bio is cut to 150 characters and the gender goes through
    ``_normalize_gender``.

    Args:
        profiles: Profiles to serialise.
        file_path: Destination JSON file, written as UTF-8.
    """
    def as_record(position, profile):
        # user_id must always be present; use the list position if unset.
        record = {
            "user_id": position if profile.user_id is None else profile.user_id,
            "username": profile.user_name,
            "name": profile.name,
            "bio": profile.bio[:150] if profile.bio else f"{profile.name}",
            "persona": profile.persona or f"{profile.name} is a participant in social discussions.",
            "karma": profile.karma or 1000,
            "created_at": profile.created_at,
            # Fields below always receive a default value.
            "age": profile.age or 30,
            "gender": self._normalize_gender(profile.gender),
            "mbti": profile.mbti or "ISTJ",
            "country": profile.country or "中国",
        }
        # Optional fields are written only when they hold a value.
        for optional_key in ("profession", "interested_topics"):
            value = getattr(profile, optional_key)
            if value:
                record[optional_key] = value
        return record

    data = [as_record(position, profile) for position, profile in enumerate(profiles)]

    with open(file_path, 'w', encoding='utf-8') as handle:
        json.dump(data, handle, ensure_ascii=False, indent=2)

    logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON格式,包含user_id字段)")
================================================ FILE: backend/app/services/ontology_generator.py ================================================ """ 本体生成服务 接口1:分析文本内容,生成适合社会模拟的实体和关系类型定义 """ import json from typing import Dict, Any, List, Optional from ..utils.llm_client import LLMClient # 本体生成的系统提示词 ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。 **重要:你必须输出有效的JSON格式数据,不要输出任何其他内容。** ## 核心任务背景 我们正在构建一个**社交媒体舆论模拟系统**。在这个系统中: - 每个实体都是一个可以在社交媒体上发声、互动、传播信息的"账号"或"主体" - 实体之间会相互影响、转发、评论、回应 - 我们需要模拟舆论事件中各方的反应和信息传播路径 因此,**实体必须是现实中真实存在的、可以在社媒上发声和互动的主体**: **可以是**: - 具体的个人(公众人物、当事人、意见领袖、专家学者、普通人) - 公司、企业(包括其官方账号) - 组织机构(大学、协会、NGO、工会等) - 政府部门、监管机构 - 媒体机构(报纸、电视台、自媒体、网站) - 社交媒体平台本身 - 特定群体代表(如校友会、粉丝团、维权群体等) **不可以是**: - 抽象概念(如"舆论"、"情绪"、"趋势") - 主题/话题(如"学术诚信"、"教育改革") - 观点/态度(如"支持方"、"反对方") ## 输出格式 请输出JSON格式,包含以下结构: ```json { "entity_types": [ { "name": "实体类型名称(英文,PascalCase)", "description": "简短描述(英文,不超过100字符)", "attributes": [ { "name": "属性名(英文,snake_case)", "type": "text", "description": "属性描述" } ], "examples": ["示例实体1", "示例实体2"] } ], "edge_types": [ { "name": "关系类型名称(英文,UPPER_SNAKE_CASE)", "description": "简短描述(英文,不超过100字符)", "source_targets": [ {"source": "源实体类型", "target": "目标实体类型"} ], "attributes": [] } ], "analysis_summary": "对文本内容的简要分析说明(中文)" } ``` ## 设计指南(极其重要!) ### 1. 实体类型设计 - 必须严格遵守 **数量要求:必须正好10个实体类型** **层次结构要求(必须同时包含具体类型和兜底类型)**: 你的10个实体类型必须包含以下层次: A. **兜底类型(必须包含,放在列表最后2个)**: - `Person`: 任何自然人个体的兜底类型。当一个人不属于其他更具体的人物类型时,归入此类。 - `Organization`: 任何组织机构的兜底类型。当一个组织不属于其他更具体的组织类型时,归入此类。 B. **具体类型(8个,根据文本内容设计)**: - 针对文本中出现的主要角色,设计更具体的类型 - 例如:如果文本涉及学术事件,可以有 `Student`, `Professor`, `University` - 例如:如果文本涉及商业事件,可以有 `Company`, `CEO`, `Employee` **为什么需要兜底类型**: - 文本中会出现各种人物,如"中小学教师"、"路人甲"、"某位网友" - 如果没有专门的类型匹配,他们应该被归入 `Person` - 同理,小型组织、临时团体等应该归入 `Organization` **具体类型的设计原则**: - 从文本中识别出高频出现或关键的角色类型 - 每个具体类型应该有明确的边界,避免重叠 - description 必须清晰说明这个类型和兜底类型的区别 ### 2. 关系类型设计 - 数量:6-10个 - 关系应该反映社媒互动中的真实联系 - 确保关系的 source_targets 涵盖你定义的实体类型 ### 3. 
属性设计 - 每个实体类型1-3个关键属性 - **注意**:属性名不能使用 `name`、`uuid`、`group_id`、`created_at`、`summary`(这些是系统保留字) - 推荐使用:`full_name`, `title`, `role`, `position`, `location`, `description` 等 ## 实体类型参考 **个人类(具体)**: - Student: 学生 - Professor: 教授/学者 - Journalist: 记者 - Celebrity: 明星/网红 - Executive: 高管 - Official: 政府官员 - Lawyer: 律师 - Doctor: 医生 **个人类(兜底)**: - Person: 任何自然人(不属于上述具体类型时使用) **组织类(具体)**: - University: 高校 - Company: 公司企业 - GovernmentAgency: 政府机构 - MediaOutlet: 媒体机构 - Hospital: 医院 - School: 中小学 - NGO: 非政府组织 **组织类(兜底)**: - Organization: 任何组织机构(不属于上述具体类型时使用) ## 关系类型参考 - WORKS_FOR: 工作于 - STUDIES_AT: 就读于 - AFFILIATED_WITH: 隶属于 - REPRESENTS: 代表 - REGULATES: 监管 - REPORTS_ON: 报道 - COMMENTS_ON: 评论 - RESPONDS_TO: 回应 - SUPPORTS: 支持 - OPPOSES: 反对 - COLLABORATES_WITH: 合作 - COMPETES_WITH: 竞争 """ class OntologyGenerator: """ 本体生成器 分析文本内容,生成实体和关系类型定义 """ def __init__(self, llm_client: Optional[LLMClient] = None): self.llm_client = llm_client or LLMClient() def generate( self, document_texts: List[str], simulation_requirement: str, additional_context: Optional[str] = None ) -> Dict[str, Any]: """ 生成本体定义 Args: document_texts: 文档文本列表 simulation_requirement: 模拟需求描述 additional_context: 额外上下文 Returns: 本体定义(entity_types, edge_types等) """ # 构建用户消息 user_message = self._build_user_message( document_texts, simulation_requirement, additional_context ) messages = [ {"role": "system", "content": ONTOLOGY_SYSTEM_PROMPT}, {"role": "user", "content": user_message} ] # 调用LLM result = self.llm_client.chat_json( messages=messages, temperature=0.3, max_tokens=4096 ) # 验证和后处理 result = self._validate_and_process(result) return result # 传给 LLM 的文本最大长度(5万字) MAX_TEXT_LENGTH_FOR_LLM = 50000 def _build_user_message( self, document_texts: List[str], simulation_requirement: str, additional_context: Optional[str] ) -> str: """构建用户消息""" # 合并文本 combined_text = "\n\n---\n\n".join(document_texts) original_length = len(combined_text) # 如果文本超过5万字,截断(仅影响传给LLM的内容,不影响图谱构建) if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM: 
combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM] combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..." message = f"""## 模拟需求 {simulation_requirement} ## 文档内容 {combined_text} """ if additional_context: message += f""" ## 额外说明 {additional_context} """ message += """ 请根据以上内容,设计适合社会舆论模拟的实体类型和关系类型。 **必须遵守的规则**: 1. 必须正好输出10个实体类型 2. 最后2个必须是兜底类型:Person(个人兜底)和 Organization(组织兜底) 3. 前8个是根据文本内容设计的具体类型 4. 所有实体类型必须是现实中可以发声的主体,不能是抽象概念 5. 属性名不能使用 name、uuid、group_id 等保留字,用 full_name、org_name 等替代 """ return message def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: """验证和后处理结果""" # 确保必要字段存在 if "entity_types" not in result: result["entity_types"] = [] if "edge_types" not in result: result["edge_types"] = [] if "analysis_summary" not in result: result["analysis_summary"] = "" # 验证实体类型 for entity in result["entity_types"]: if "attributes" not in entity: entity["attributes"] = [] if "examples" not in entity: entity["examples"] = [] # 确保description不超过100字符 if len(entity.get("description", "")) > 100: entity["description"] = entity["description"][:97] + "..." # 验证关系类型 for edge in result["edge_types"]: if "source_targets" not in edge: edge["source_targets"] = [] if "attributes" not in edge: edge["attributes"] = [] if len(edge.get("description", "")) > 100: edge["description"] = edge["description"][:97] + "..." 
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 MAX_ENTITY_TYPES = 10 MAX_EDGE_TYPES = 10 # 兜底类型定义 person_fallback = { "name": "Person", "description": "Any individual person not fitting other specific person types.", "attributes": [ {"name": "full_name", "type": "text", "description": "Full name of the person"}, {"name": "role", "type": "text", "description": "Role or occupation"} ], "examples": ["ordinary citizen", "anonymous netizen"] } organization_fallback = { "name": "Organization", "description": "Any organization not fitting other specific organization types.", "attributes": [ {"name": "org_name", "type": "text", "description": "Name of the organization"}, {"name": "org_type", "type": "text", "description": "Type of organization"} ], "examples": ["small business", "community group"] } # 检查是否已有兜底类型 entity_names = {e["name"] for e in result["entity_types"]} has_person = "Person" in entity_names has_organization = "Organization" in entity_names # 需要添加的兜底类型 fallbacks_to_add = [] if not has_person: fallbacks_to_add.append(person_fallback) if not has_organization: fallbacks_to_add.append(organization_fallback) if fallbacks_to_add: current_count = len(result["entity_types"]) needed_slots = len(fallbacks_to_add) # 如果添加后会超过 10 个,需要移除一些现有类型 if current_count + needed_slots > MAX_ENTITY_TYPES: # 计算需要移除多少个 to_remove = current_count + needed_slots - MAX_ENTITY_TYPES # 从末尾移除(保留前面更重要的具体类型) result["entity_types"] = result["entity_types"][:-to_remove] # 添加兜底类型 result["entity_types"].extend(fallbacks_to_add) # 最终确保不超过限制(防御性编程) if len(result["entity_types"]) > MAX_ENTITY_TYPES: result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES] if len(result["edge_types"]) > MAX_EDGE_TYPES: result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES] return result def generate_python_code(self, ontology: Dict[str, Any]) -> str: """ 将本体定义转换为Python代码(类似ontology.py) Args: ontology: 本体定义 Returns: Python代码字符串 """ code_lines = [ '"""', '自定义实体类型定义', '由MiroFish自动生成,用于社会舆论模拟', '"""', '', 
'from pydantic import Field', 'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel', '', '', '# ============== 实体类型定义 ==============', '', ] # 生成实体类型 for entity in ontology.get("entity_types", []): name = entity["name"] desc = entity.get("description", f"A {name} entity.") code_lines.append(f'class {name}(EntityModel):') code_lines.append(f' """{desc}"""') attrs = entity.get("attributes", []) if attrs: for attr in attrs: attr_name = attr["name"] attr_desc = attr.get("description", attr_name) code_lines.append(f' {attr_name}: EntityText = Field(') code_lines.append(f' description="{attr_desc}",') code_lines.append(f' default=None') code_lines.append(f' )') else: code_lines.append(' pass') code_lines.append('') code_lines.append('') code_lines.append('# ============== 关系类型定义 ==============') code_lines.append('') # 生成关系类型 for edge in ontology.get("edge_types", []): name = edge["name"] # 转换为PascalCase类名 class_name = ''.join(word.capitalize() for word in name.split('_')) desc = edge.get("description", f"A {name} relationship.") code_lines.append(f'class {class_name}(EdgeModel):') code_lines.append(f' """{desc}"""') attrs = edge.get("attributes", []) if attrs: for attr in attrs: attr_name = attr["name"] attr_desc = attr.get("description", attr_name) code_lines.append(f' {attr_name}: EntityText = Field(') code_lines.append(f' description="{attr_desc}",') code_lines.append(f' default=None') code_lines.append(f' )') else: code_lines.append(' pass') code_lines.append('') code_lines.append('') # 生成类型字典 code_lines.append('# ============== 类型配置 ==============') code_lines.append('') code_lines.append('ENTITY_TYPES = {') for entity in ontology.get("entity_types", []): name = entity["name"] code_lines.append(f' "{name}": {name},') code_lines.append('}') code_lines.append('') code_lines.append('EDGE_TYPES = {') for edge in ontology.get("edge_types", []): name = edge["name"] class_name = ''.join(word.capitalize() for word in name.split('_')) 
code_lines.append(f' "{name}": {class_name},') code_lines.append('}') code_lines.append('') # 生成边的source_targets映射 code_lines.append('EDGE_SOURCE_TARGETS = {') for edge in ontology.get("edge_types", []): name = edge["name"] source_targets = edge.get("source_targets", []) if source_targets: st_list = ', '.join([ f'{{"source": "{st.get("source", "Entity")}", "target": "{st.get("target", "Entity")}"}}' for st in source_targets ]) code_lines.append(f' "{name}": [{st_list}],') code_lines.append('}') return '\n'.join(code_lines) ================================================ FILE: backend/app/services/report_agent.py ================================================ """ Report Agent服务 使用LangChain + Zep实现ReACT模式的模拟报告生成 功能: 1. 根据模拟需求和Zep图谱信息生成报告 2. 先规划目录结构,然后分段生成 3. 每段采用ReACT多轮思考与反思模式 4. 支持与用户对话,在对话中自主调用检索工具 """ import os import json import time import re from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass, field from datetime import datetime from enum import Enum from ..config import Config from ..utils.llm_client import LLMClient from ..utils.logger import get_logger from .zep_tools import ( ZepToolsService, SearchResult, InsightForgeResult, PanoramaResult, InterviewResult ) logger = get_logger('mirofish.report_agent') class ReportLogger: """ Report Agent 详细日志记录器 在报告文件夹中生成 agent_log.jsonl 文件,记录每一步详细动作。 每行是一个完整的 JSON 对象,包含时间戳、动作类型、详细内容等。 """ def __init__(self, report_id: str): """ 初始化日志记录器 Args: report_id: 报告ID,用于确定日志文件路径 """ self.report_id = report_id self.log_file_path = os.path.join( Config.UPLOAD_FOLDER, 'reports', report_id, 'agent_log.jsonl' ) self.start_time = datetime.now() self._ensure_log_file() def _ensure_log_file(self): """确保日志文件所在目录存在""" log_dir = os.path.dirname(self.log_file_path) os.makedirs(log_dir, exist_ok=True) def _get_elapsed_time(self) -> float: """获取从开始到现在的耗时(秒)""" return (datetime.now() - self.start_time).total_seconds() def log( self, action: str, stage: str, details: Dict[str, Any], section_title: str = 
None, section_index: int = None ): """ 记录一条日志 Args: action: 动作类型,如 'start', 'tool_call', 'llm_response', 'section_complete' 等 stage: 当前阶段,如 'planning', 'generating', 'completed' details: 详细内容字典,不截断 section_title: 当前章节标题(可选) section_index: 当前章节索引(可选) """ log_entry = { "timestamp": datetime.now().isoformat(), "elapsed_seconds": round(self._get_elapsed_time(), 2), "report_id": self.report_id, "action": action, "stage": stage, "section_title": section_title, "section_index": section_index, "details": details } # 追加写入 JSONL 文件 with open(self.log_file_path, 'a', encoding='utf-8') as f: f.write(json.dumps(log_entry, ensure_ascii=False) + '\n') def log_start(self, simulation_id: str, graph_id: str, simulation_requirement: str): """记录报告生成开始""" self.log( action="report_start", stage="pending", details={ "simulation_id": simulation_id, "graph_id": graph_id, "simulation_requirement": simulation_requirement, "message": "报告生成任务开始" } ) def log_planning_start(self): """记录大纲规划开始""" self.log( action="planning_start", stage="planning", details={"message": "开始规划报告大纲"} ) def log_planning_context(self, context: Dict[str, Any]): """记录规划时获取的上下文信息""" self.log( action="planning_context", stage="planning", details={ "message": "获取模拟上下文信息", "context": context } ) def log_planning_complete(self, outline_dict: Dict[str, Any]): """记录大纲规划完成""" self.log( action="planning_complete", stage="planning", details={ "message": "大纲规划完成", "outline": outline_dict } ) def log_section_start(self, section_title: str, section_index: int): """记录章节生成开始""" self.log( action="section_start", stage="generating", section_title=section_title, section_index=section_index, details={"message": f"开始生成章节: {section_title}"} ) def log_react_thought(self, section_title: str, section_index: int, iteration: int, thought: str): """记录 ReACT 思考过程""" self.log( action="react_thought", stage="generating", section_title=section_title, section_index=section_index, details={ "iteration": iteration, "thought": thought, "message": f"ReACT 
第{iteration}轮思考" } ) def log_tool_call( self, section_title: str, section_index: int, tool_name: str, parameters: Dict[str, Any], iteration: int ): """记录工具调用""" self.log( action="tool_call", stage="generating", section_title=section_title, section_index=section_index, details={ "iteration": iteration, "tool_name": tool_name, "parameters": parameters, "message": f"调用工具: {tool_name}" } ) def log_tool_result( self, section_title: str, section_index: int, tool_name: str, result: str, iteration: int ): """记录工具调用结果(完整内容,不截断)""" self.log( action="tool_result", stage="generating", section_title=section_title, section_index=section_index, details={ "iteration": iteration, "tool_name": tool_name, "result": result, # 完整结果,不截断 "result_length": len(result), "message": f"工具 {tool_name} 返回结果" } ) def log_llm_response( self, section_title: str, section_index: int, response: str, iteration: int, has_tool_calls: bool, has_final_answer: bool ): """记录 LLM 响应(完整内容,不截断)""" self.log( action="llm_response", stage="generating", section_title=section_title, section_index=section_index, details={ "iteration": iteration, "response": response, # 完整响应,不截断 "response_length": len(response), "has_tool_calls": has_tool_calls, "has_final_answer": has_final_answer, "message": f"LLM 响应 (工具调用: {has_tool_calls}, 最终答案: {has_final_answer})" } ) def log_section_content( self, section_title: str, section_index: int, content: str, tool_calls_count: int ): """记录章节内容生成完成(仅记录内容,不代表整个章节完成)""" self.log( action="section_content", stage="generating", section_title=section_title, section_index=section_index, details={ "content": content, # 完整内容,不截断 "content_length": len(content), "tool_calls_count": tool_calls_count, "message": f"章节 {section_title} 内容生成完成" } ) def log_section_full_complete( self, section_title: str, section_index: int, full_content: str ): """ 记录章节生成完成 前端应监听此日志来判断一个章节是否真正完成,并获取完整内容 """ self.log( action="section_complete", stage="generating", section_title=section_title, section_index=section_index, 
details={ "content": full_content, "content_length": len(full_content), "message": f"章节 {section_title} 生成完成" } ) def log_report_complete(self, total_sections: int, total_time_seconds: float): """记录报告生成完成""" self.log( action="report_complete", stage="completed", details={ "total_sections": total_sections, "total_time_seconds": round(total_time_seconds, 2), "message": "报告生成完成" } ) def log_error(self, error_message: str, stage: str, section_title: str = None): """记录错误""" self.log( action="error", stage=stage, section_title=section_title, section_index=None, details={ "error": error_message, "message": f"发生错误: {error_message}" } ) class ReportConsoleLogger: """ Report Agent 控制台日志记录器 将控制台风格的日志(INFO、WARNING等)写入报告文件夹中的 console_log.txt 文件。 这些日志与 agent_log.jsonl 不同,是纯文本格式的控制台输出。 """ def __init__(self, report_id: str): """ 初始化控制台日志记录器 Args: report_id: 报告ID,用于确定日志文件路径 """ self.report_id = report_id self.log_file_path = os.path.join( Config.UPLOAD_FOLDER, 'reports', report_id, 'console_log.txt' ) self._ensure_log_file() self._file_handler = None self._setup_file_handler() def _ensure_log_file(self): """确保日志文件所在目录存在""" log_dir = os.path.dirname(self.log_file_path) os.makedirs(log_dir, exist_ok=True) def _setup_file_handler(self): """设置文件处理器,将日志同时写入文件""" import logging # 创建文件处理器 self._file_handler = logging.FileHandler( self.log_file_path, mode='a', encoding='utf-8' ) self._file_handler.setLevel(logging.INFO) # 使用与控制台相同的简洁格式 formatter = logging.Formatter( '[%(asctime)s] %(levelname)s: %(message)s', datefmt='%H:%M:%S' ) self._file_handler.setFormatter(formatter) # 添加到 report_agent 相关的 logger loggers_to_attach = [ 'mirofish.report_agent', 'mirofish.zep_tools', ] for logger_name in loggers_to_attach: target_logger = logging.getLogger(logger_name) # 避免重复添加 if self._file_handler not in target_logger.handlers: target_logger.addHandler(self._file_handler) def close(self): """关闭文件处理器并从 logger 中移除""" import logging if self._file_handler: loggers_to_detach = [ 'mirofish.report_agent', 
'mirofish.zep_tools', ] for logger_name in loggers_to_detach: target_logger = logging.getLogger(logger_name) if self._file_handler in target_logger.handlers: target_logger.removeHandler(self._file_handler) self._file_handler.close() self._file_handler = None def __del__(self): """析构时确保关闭文件处理器""" self.close() class ReportStatus(str, Enum): """报告状态""" PENDING = "pending" PLANNING = "planning" GENERATING = "generating" COMPLETED = "completed" FAILED = "failed" @dataclass class ReportSection: """报告章节""" title: str content: str = "" def to_dict(self) -> Dict[str, Any]: return { "title": self.title, "content": self.content } def to_markdown(self, level: int = 2) -> str: """转换为Markdown格式""" md = f"{'#' * level} {self.title}\n\n" if self.content: md += f"{self.content}\n\n" return md @dataclass class ReportOutline: """报告大纲""" title: str summary: str sections: List[ReportSection] def to_dict(self) -> Dict[str, Any]: return { "title": self.title, "summary": self.summary, "sections": [s.to_dict() for s in self.sections] } def to_markdown(self) -> str: """转换为Markdown格式""" md = f"# {self.title}\n\n" md += f"> {self.summary}\n\n" for section in self.sections: md += section.to_markdown() return md @dataclass class Report: """完整报告""" report_id: str simulation_id: str graph_id: str simulation_requirement: str status: ReportStatus outline: Optional[ReportOutline] = None markdown_content: str = "" created_at: str = "" completed_at: str = "" error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: return { "report_id": self.report_id, "simulation_id": self.simulation_id, "graph_id": self.graph_id, "simulation_requirement": self.simulation_requirement, "status": self.status.value, "outline": self.outline.to_dict() if self.outline else None, "markdown_content": self.markdown_content, "created_at": self.created_at, "completed_at": self.completed_at, "error": self.error } # ═══════════════════════════════════════════════════════════════ # Prompt 模板常量 # 
═══════════════════════════════════════════════════════════════ # ── 工具描述 ── TOOL_DESC_INSIGHT_FORGE = """\ 【深度洞察检索 - 强大的检索工具】 这是我们强大的检索函数,专为深度分析设计。它会: 1. 自动将你的问题分解为多个子问题 2. 从多个维度检索模拟图谱中的信息 3. 整合语义搜索、实体分析、关系链追踪的结果 4. 返回最全面、最深度的检索内容 【使用场景】 - 需要深入分析某个话题 - 需要了解事件的多个方面 - 需要获取支撑报告章节的丰富素材 【返回内容】 - 相关事实原文(可直接引用) - 核心实体洞察 - 关系链分析""" TOOL_DESC_PANORAMA_SEARCH = """\ 【广度搜索 - 获取全貌视图】 这个工具用于获取模拟结果的完整全貌,特别适合了解事件演变过程。它会: 1. 获取所有相关节点和关系 2. 区分当前有效的事实和历史/过期的事实 3. 帮助你了解舆情是如何演变的 【使用场景】 - 需要了解事件的完整发展脉络 - 需要对比不同阶段的舆情变化 - 需要获取全面的实体和关系信息 【返回内容】 - 当前有效事实(模拟最新结果) - 历史/过期事实(演变记录) - 所有涉及的实体""" TOOL_DESC_QUICK_SEARCH = """\ 【简单搜索 - 快速检索】 轻量级的快速检索工具,适合简单、直接的信息查询。 【使用场景】 - 需要快速查找某个具体信息 - 需要验证某个事实 - 简单的信息检索 【返回内容】 - 与查询最相关的事实列表""" TOOL_DESC_INTERVIEW_AGENTS = """\ 【深度采访 - 真实Agent采访(双平台)】 调用OASIS模拟环境的采访API,对正在运行的模拟Agent进行真实采访! 这不是LLM模拟,而是调用真实的采访接口获取模拟Agent的原始回答。 默认在Twitter和Reddit两个平台同时采访,获取更全面的观点。 功能流程: 1. 自动读取人设文件,了解所有模拟Agent 2. 智能选择与采访主题最相关的Agent(如学生、媒体、官方等) 3. 自动生成采访问题 4. 调用 /api/simulation/interview/batch 接口在双平台进行真实采访 5. 整合所有采访结果,提供多视角分析 【使用场景】 - 需要从不同角色视角了解事件看法(学生怎么看?媒体怎么看?官方怎么说?) - 需要收集多方意见和立场 - 需要获取模拟Agent的真实回答(来自OASIS模拟环境) - 想让报告更生动,包含"采访实录" 【返回内容】 - 被采访Agent的身份信息 - 各Agent在Twitter和Reddit两个平台的采访回答 - 关键引言(可直接引用) - 采访摘要和观点对比 【重要】需要OASIS模拟环境正在运行才能使用此功能!""" # ── 大纲规划 prompt ── PLAN_SYSTEM_PROMPT = """\ 你是一个「未来预测报告」的撰写专家,拥有对模拟世界的「上帝视角」——你可以洞察模拟中每一位Agent的行为、言论和互动。 【核心理念】 我们构建了一个模拟世界,并向其中注入了特定的「模拟需求」作为变量。模拟世界的演化结果,就是对未来可能发生情况的预测。你正在观察的不是"实验数据",而是"未来的预演"。 【你的任务】 撰写一份「未来预测报告」,回答: 1. 在我们设定的条件下,未来发生了什么? 2. 各类Agent(人群)是如何反应和行动? 3. 这个模拟揭示了哪些值得关注的未来趋势和风险? 
【报告定位】 - ✅ 这是一份基于模拟的未来预测报告,揭示"如果这样,未来会怎样" - ✅ 聚焦于预测结果:事件走向、群体反应、涌现现象、潜在风险 - ✅ 模拟世界中的Agent言行就是对未来人群行为的预测 - ❌ 不是对现实世界现状的分析 - ❌ 不是泛泛而谈的舆情综述 【章节数量限制】 - 最少2个章节,最多5个章节 - 不需要子章节,每个章节直接撰写完整内容 - 内容要精炼,聚焦于核心预测发现 - 章节结构由你根据预测结果自主设计 请输出JSON格式的报告大纲,格式如下: { "title": "报告标题", "summary": "报告摘要(一句话概括核心预测发现)", "sections": [ { "title": "章节标题", "description": "章节内容描述" } ] } 注意:sections数组最少2个,最多5个元素!""" PLAN_USER_PROMPT_TEMPLATE = """\ 【预测场景设定】 我们向模拟世界注入的变量(模拟需求):{simulation_requirement} 【模拟世界规模】 - 参与模拟的实体数量: {total_nodes} - 实体间产生的关系数量: {total_edges} - 实体类型分布: {entity_types} - 活跃Agent数量: {total_entities} 【模拟预测到的部分未来事实样本】 {related_facts_json} 请以「上帝视角」审视这个未来预演: 1. 在我们设定的条件下,未来呈现出了什么样的状态? 2. 各类人群(Agent)是如何反应和行动的? 3. 这个模拟揭示了哪些值得关注的未来趋势? 根据预测结果,设计最合适的报告章节结构。 【再次提醒】报告章节数量:最少2个,最多5个,内容要精炼聚焦于核心预测发现。""" # ── 章节生成 prompt ── SECTION_SYSTEM_PROMPT_TEMPLATE = """\ 你是一个「未来预测报告」的撰写专家,正在撰写报告的一个章节。 报告标题: {report_title} 报告摘要: {report_summary} 预测场景(模拟需求): {simulation_requirement} 当前要撰写的章节: {section_title} ═══════════════════════════════════════════════════════════════ 【核心理念】 ═══════════════════════════════════════════════════════════════ 模拟世界是对未来的预演。我们向模拟世界注入了特定条件(模拟需求), 模拟中Agent的行为和互动,就是对未来人群行为的预测。 你的任务是: - 揭示在设定条件下,未来发生了什么 - 预测各类人群(Agent)是如何反应和行动的 - 发现值得关注的未来趋势、风险和机会 ❌ 不要写成对现实世界现状的分析 ✅ 要聚焦于"未来会怎样"——模拟结果就是预测的未来 ═══════════════════════════════════════════════════════════════ 【最重要的规则 - 必须遵守】 ═══════════════════════════════════════════════════════════════ 1. 【必须调用工具观察模拟世界】 - 你正在以「上帝视角」观察未来的预演 - 所有内容必须来自模拟世界中发生的事件和Agent言行 - 禁止使用你自己的知识来编写报告内容 - 每个章节至少调用3次工具(最多5次)来观察模拟的世界,它代表了未来 2. 【必须引用Agent的原始言行】 - Agent的发言和行为是对未来人群行为的预测 - 在报告中使用引用格式展示这些预测,例如: > "某类人群会表示:原文内容..." - 这些引用是模拟预测的核心证据 3. 【语言一致性 - 引用内容必须翻译为报告语言】 - 工具返回的内容可能包含英文或中英文混杂的表述 - 如果模拟需求和材料原文是中文的,报告必须全部使用中文撰写 - 当你引用工具返回的英文或中英混杂内容时,必须将其翻译为流畅的中文后再写入报告 - 翻译时保持原意不变,确保表述自然通顺 - 这一规则同时适用于正文和引用块(> 格式)中的内容 4. 
【忠实呈现预测结果】 - 报告内容必须反映模拟世界中的代表未来的模拟结果 - 不要添加模拟中不存在的信息 - 如果某方面信息不足,如实说明 ═══════════════════════════════════════════════════════════════ 【⚠️ 格式规范 - 极其重要!】 ═══════════════════════════════════════════════════════════════ 【一个章节 = 最小内容单位】 - 每个章节是报告的最小分块单位 - ❌ 禁止在章节内使用任何 Markdown 标题(#、##、###、#### 等) - ❌ 禁止在内容开头添加章节主标题 - ✅ 章节标题由系统自动添加,你只需撰写纯正文内容 - ✅ 使用**粗体**、段落分隔、引用、列表来组织内容,但不要用标题 【正确示例】 ``` 本章节分析了事件的舆论传播态势。通过对模拟数据的深入分析,我们发现... **首发引爆阶段** 微博作为舆情的第一现场,承担了信息首发的核心功能: > "微博贡献了68%的首发声量..." **情绪放大阶段** 抖音平台进一步放大了事件影响力: - 视觉冲击力强 - 情绪共鸣度高 ``` 【错误示例】 ``` ## 执行摘要 ← 错误!不要添加任何标题 ### 一、首发阶段 ← 错误!不要用###分小节 #### 1.1 详细分析 ← 错误!不要用####细分 本章节分析了... ``` ═══════════════════════════════════════════════════════════════ 【可用检索工具】(每章节调用3-5次) ═══════════════════════════════════════════════════════════════ {tools_description} 【工具使用建议 - 请混合使用不同工具,不要只用一种】 - insight_forge: 深度洞察分析,自动分解问题并多维度检索事实和关系 - panorama_search: 广角全景搜索,了解事件全貌、时间线和演变过程 - quick_search: 快速验证某个具体信息点 - interview_agents: 采访模拟Agent,获取不同角色的第一人称观点和真实反应 ═══════════════════════════════════════════════════════════════ 【工作流程】 ═══════════════════════════════════════════════════════════════ 每次回复你只能做以下两件事之一(不可同时做): 选项A - 调用工具: 输出你的思考,然后用以下格式调用一个工具: {{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} 系统会执行工具并把结果返回给你。你不需要也不能自己编写工具返回结果。 选项B - 输出最终内容: 当你已通过工具获取了足够信息,以 "Final Answer:" 开头输出章节内容。 ⚠️ 严格禁止: - 禁止在一次回复中同时包含工具调用和 Final Answer - 禁止自己编造工具返回结果(Observation),所有工具结果由系统注入 - 每次回复最多调用一个工具 ═══════════════════════════════════════════════════════════════ 【章节内容要求】 ═══════════════════════════════════════════════════════════════ 1. 内容必须基于工具检索到的模拟数据 2. 大量引用原文来展示模拟效果 3. 使用Markdown格式(但禁止使用标题): - 使用 **粗体文字** 标记重点(代替子标题) - 使用列表(-或1.2.3.)组织要点 - 使用空行分隔不同段落 - ❌ 禁止使用 #、##、###、#### 等任何标题语法 4. 【引用格式规范 - 必须单独成段】 引用必须独立成段,前后各有一个空行,不能混在段落中: ✅ 正确格式: ``` 校方的回应被认为缺乏实质内容。 > "校方的应对模式在瞬息万变的社交媒体环境中显得僵化和迟缓。" 这一评价反映了公众的普遍不满。 ``` ❌ 错误格式: ``` 校方的回应被认为缺乏实质内容。> "校方的应对模式..." 这一评价反映了... ``` 5. 保持与其他章节的逻辑连贯性 6. 【避免重复】仔细阅读下方已完成的章节内容,不要重复描述相同的信息 7. 
【再次强调】不要添加任何标题!用**粗体**代替小节标题""" SECTION_USER_PROMPT_TEMPLATE = """\ 已完成的章节内容(请仔细阅读,避免重复): {previous_content} ═══════════════════════════════════════════════════════════════ 【当前任务】撰写章节: {section_title} ═══════════════════════════════════════════════════════════════ 【重要提醒】 1. 仔细阅读上方已完成的章节,避免重复相同的内容! 2. 开始前必须先调用工具获取模拟数据 3. 请混合使用不同工具,不要只用一种 4. 报告内容必须来自检索结果,不要使用自己的知识 【⚠️ 格式警告 - 必须遵守】 - ❌ 不要写任何标题(#、##、###、####都不行) - ❌ 不要写"{section_title}"作为开头 - ✅ 章节标题由系统自动添加 - ✅ 直接写正文,用**粗体**代替小节标题 请开始: 1. 首先思考(Thought)这个章节需要什么信息 2. 然后调用工具(Action)获取模拟数据 3. 收集足够信息后输出 Final Answer(纯正文,无任何标题)""" # ── ReACT 循环内消息模板 ── REACT_OBSERVATION_TEMPLATE = """\ Observation(检索结果): ═══ 工具 {tool_name} 返回 ═══ {result} ═══════════════════════════════════════════════════════════════ 已调用工具 {tool_calls_count}/{max_tool_calls} 次(已用: {used_tools_str}){unused_hint} - 如果信息充分:以 "Final Answer:" 开头输出章节内容(必须引用上述原文) - 如果需要更多信息:调用一个工具继续检索 ═══════════════════════════════════════════════════════════════""" REACT_INSUFFICIENT_TOOLS_MSG = ( "【注意】你只调用了{tool_calls_count}次工具,至少需要{min_tool_calls}次。" "请再调用工具获取更多模拟数据,然后再输出 Final Answer。{unused_hint}" ) REACT_INSUFFICIENT_TOOLS_MSG_ALT = ( "当前只调用了 {tool_calls_count} 次工具,至少需要 {min_tool_calls} 次。" "请调用工具获取模拟数据。{unused_hint}" ) REACT_TOOL_LIMIT_MSG = ( "工具调用次数已达上限({tool_calls_count}/{max_tool_calls}),不能再调用工具。" '请立即基于已获取的信息,以 "Final Answer:" 开头输出章节内容。' ) REACT_UNUSED_TOOLS_HINT = "\n💡 你还没有使用过: {unused_list},建议尝试不同工具获取多角度信息" REACT_FORCE_FINAL_MSG = "已达到工具调用限制,请直接输出 Final Answer: 并生成章节内容。" # ── Chat prompt ── CHAT_SYSTEM_PROMPT_TEMPLATE = """\ 你是一个简洁高效的模拟预测助手。 【背景】 预测条件: {simulation_requirement} 【已生成的分析报告】 {report_content} 【规则】 1. 优先基于上述报告内容回答问题 2. 直接回答问题,避免冗长的思考论述 3. 仅在报告内容不足以回答时,才调用工具检索更多数据 4. 
回答要简洁、清晰、有条理 【可用工具】(仅在需要时使用,最多调用1-2次) {tools_description} 【工具调用格式】 {{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} 【回答风格】 - 简洁直接,不要长篇大论 - 使用 > 格式引用关键内容 - 优先给出结论,再解释原因""" CHAT_OBSERVATION_SUFFIX = "\n\n请简洁回答问题。" # ═══════════════════════════════════════════════════════════════ # ReportAgent 主类 # ═══════════════════════════════════════════════════════════════ class ReportAgent: """ Report Agent - 模拟报告生成Agent 采用ReACT(Reasoning + Acting)模式: 1. 规划阶段:分析模拟需求,规划报告目录结构 2. 生成阶段:逐章节生成内容,每章节可多次调用工具获取信息 3. 反思阶段:检查内容完整性和准确性 """ # 最大工具调用次数(每个章节) MAX_TOOL_CALLS_PER_SECTION = 5 # 最大反思轮数 MAX_REFLECTION_ROUNDS = 3 # 对话中的最大工具调用次数 MAX_TOOL_CALLS_PER_CHAT = 2 def __init__( self, graph_id: str, simulation_id: str, simulation_requirement: str, llm_client: Optional[LLMClient] = None, zep_tools: Optional[ZepToolsService] = None ): """ 初始化Report Agent Args: graph_id: 图谱ID simulation_id: 模拟ID simulation_requirement: 模拟需求描述 llm_client: LLM客户端(可选) zep_tools: Zep工具服务(可选) """ self.graph_id = graph_id self.simulation_id = simulation_id self.simulation_requirement = simulation_requirement self.llm = llm_client or LLMClient() self.zep_tools = zep_tools or ZepToolsService() # 工具定义 self.tools = self._define_tools() # 日志记录器(在 generate_report 中初始化) self.report_logger: Optional[ReportLogger] = None # 控制台日志记录器(在 generate_report 中初始化) self.console_logger: Optional[ReportConsoleLogger] = None logger.info(f"ReportAgent 初始化完成: graph_id={graph_id}, simulation_id={simulation_id}") def _define_tools(self) -> Dict[str, Dict[str, Any]]: """定义可用工具""" return { "insight_forge": { "name": "insight_forge", "description": TOOL_DESC_INSIGHT_FORGE, "parameters": { "query": "你想深入分析的问题或话题", "report_context": "当前报告章节的上下文(可选,有助于生成更精准的子问题)" } }, "panorama_search": { "name": "panorama_search", "description": TOOL_DESC_PANORAMA_SEARCH, "parameters": { "query": "搜索查询,用于相关性排序", "include_expired": "是否包含过期/历史内容(默认True)" } }, "quick_search": { "name": "quick_search", "description": TOOL_DESC_QUICK_SEARCH, "parameters": { "query": 
"搜索查询字符串", "limit": "返回结果数量(可选,默认10)" } }, "interview_agents": { "name": "interview_agents", "description": TOOL_DESC_INTERVIEW_AGENTS, "parameters": { "interview_topic": "采访主题或需求描述(如:'了解学生对宿舍甲醛事件的看法')", "max_agents": "最多采访的Agent数量(可选,默认5,最大10)" } } } def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_context: str = "") -> str: """ 执行工具调用 Args: tool_name: 工具名称 parameters: 工具参数 report_context: 报告上下文(用于InsightForge) Returns: 工具执行结果(文本格式) """ logger.info(f"执行工具: {tool_name}, 参数: {parameters}") try: if tool_name == "insight_forge": query = parameters.get("query", "") ctx = parameters.get("report_context", "") or report_context result = self.zep_tools.insight_forge( graph_id=self.graph_id, query=query, simulation_requirement=self.simulation_requirement, report_context=ctx ) return result.to_text() elif tool_name == "panorama_search": # 广度搜索 - 获取全貌 query = parameters.get("query", "") include_expired = parameters.get("include_expired", True) if isinstance(include_expired, str): include_expired = include_expired.lower() in ['true', '1', 'yes'] result = self.zep_tools.panorama_search( graph_id=self.graph_id, query=query, include_expired=include_expired ) return result.to_text() elif tool_name == "quick_search": # 简单搜索 - 快速检索 query = parameters.get("query", "") limit = parameters.get("limit", 10) if isinstance(limit, str): limit = int(limit) result = self.zep_tools.quick_search( graph_id=self.graph_id, query=query, limit=limit ) return result.to_text() elif tool_name == "interview_agents": # 深度采访 - 调用真实的OASIS采访API获取模拟Agent的回答(双平台) interview_topic = parameters.get("interview_topic", parameters.get("query", "")) max_agents = parameters.get("max_agents", 5) if isinstance(max_agents, str): max_agents = int(max_agents) max_agents = min(max_agents, 10) result = self.zep_tools.interview_agents( simulation_id=self.simulation_id, interview_requirement=interview_topic, simulation_requirement=self.simulation_requirement, max_agents=max_agents ) return 
result.to_text() # ========== 向后兼容的旧工具(内部重定向到新工具) ========== elif tool_name == "search_graph": # 重定向到 quick_search logger.info("search_graph 已重定向到 quick_search") return self._execute_tool("quick_search", parameters, report_context) elif tool_name == "get_graph_statistics": result = self.zep_tools.get_graph_statistics(self.graph_id) return json.dumps(result, ensure_ascii=False, indent=2) elif tool_name == "get_entity_summary": entity_name = parameters.get("entity_name", "") result = self.zep_tools.get_entity_summary( graph_id=self.graph_id, entity_name=entity_name ) return json.dumps(result, ensure_ascii=False, indent=2) elif tool_name == "get_simulation_context": # 重定向到 insight_forge,因为它更强大 logger.info("get_simulation_context 已重定向到 insight_forge") query = parameters.get("query", self.simulation_requirement) return self._execute_tool("insight_forge", {"query": query}, report_context) elif tool_name == "get_entities_by_type": entity_type = parameters.get("entity_type", "") nodes = self.zep_tools.get_entities_by_type( graph_id=self.graph_id, entity_type=entity_type ) result = [n.to_dict() for n in nodes] return json.dumps(result, ensure_ascii=False, indent=2) else: return f"未知工具: {tool_name}。请使用以下工具之一: insight_forge, panorama_search, quick_search" except Exception as e: logger.error(f"工具执行失败: {tool_name}, 错误: {str(e)}") return f"工具执行失败: {str(e)}" # 合法的工具名称集合,用于裸 JSON 兜底解析时校验 VALID_TOOL_NAMES = {"insight_forge", "panorama_search", "quick_search", "interview_agents"} def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: """ 从LLM响应中解析工具调用 支持的格式(按优先级): 1. {"name": "tool_name", "parameters": {...}} 2. 
裸 JSON(响应整体或单行就是一个工具调用 JSON) """ tool_calls = [] # 格式1: XML风格(标准格式) xml_pattern = r'\s*(\{.*?\})\s*' for match in re.finditer(xml_pattern, response, re.DOTALL): try: call_data = json.loads(match.group(1)) tool_calls.append(call_data) except json.JSONDecodeError: pass if tool_calls: return tool_calls # 格式2: 兜底 - LLM 直接输出裸 JSON(没包 标签) # 只在格式1未匹配时尝试,避免误匹配正文中的 JSON stripped = response.strip() if stripped.startswith('{') and stripped.endswith('}'): try: call_data = json.loads(stripped) if self._is_valid_tool_call(call_data): tool_calls.append(call_data) return tool_calls except json.JSONDecodeError: pass # 响应可能包含思考文字 + 裸 JSON,尝试提取最后一个 JSON 对象 json_pattern = r'(\{"(?:name|tool)"\s*:.*?\})\s*$' match = re.search(json_pattern, stripped, re.DOTALL) if match: try: call_data = json.loads(match.group(1)) if self._is_valid_tool_call(call_data): tool_calls.append(call_data) except json.JSONDecodeError: pass return tool_calls def _is_valid_tool_call(self, data: dict) -> bool: """校验解析出的 JSON 是否是合法的工具调用""" # 支持 {"name": ..., "parameters": ...} 和 {"tool": ..., "params": ...} 两种键名 tool_name = data.get("name") or data.get("tool") if tool_name and tool_name in self.VALID_TOOL_NAMES: # 统一键名为 name / parameters if "tool" in data: data["name"] = data.pop("tool") if "params" in data and "parameters" not in data: data["parameters"] = data.pop("params") return True return False def _get_tools_description(self) -> str: """生成工具描述文本""" desc_parts = ["可用工具:"] for name, tool in self.tools.items(): params_desc = ", ".join([f"{k}: {v}" for k, v in tool["parameters"].items()]) desc_parts.append(f"- {name}: {tool['description']}") if params_desc: desc_parts.append(f" 参数: {params_desc}") return "\n".join(desc_parts) def plan_outline( self, progress_callback: Optional[Callable] = None ) -> ReportOutline: """ 规划报告大纲 使用LLM分析模拟需求,规划报告的目录结构 Args: progress_callback: 进度回调函数 Returns: ReportOutline: 报告大纲 """ logger.info("开始规划报告大纲...") if progress_callback: progress_callback("planning", 0, "正在分析模拟需求...") # 
首先获取模拟上下文 context = self.zep_tools.get_simulation_context( graph_id=self.graph_id, simulation_requirement=self.simulation_requirement ) if progress_callback: progress_callback("planning", 30, "正在生成报告大纲...") system_prompt = PLAN_SYSTEM_PROMPT user_prompt = PLAN_USER_PROMPT_TEMPLATE.format( simulation_requirement=self.simulation_requirement, total_nodes=context.get('graph_statistics', {}).get('total_nodes', 0), total_edges=context.get('graph_statistics', {}).get('total_edges', 0), entity_types=list(context.get('graph_statistics', {}).get('entity_types', {}).keys()), total_entities=context.get('total_entities', 0), related_facts_json=json.dumps(context.get('related_facts', [])[:10], ensure_ascii=False, indent=2), ) try: response = self.llm.chat_json( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], temperature=0.3 ) if progress_callback: progress_callback("planning", 80, "正在解析大纲结构...") # 解析大纲 sections = [] for section_data in response.get("sections", []): sections.append(ReportSection( title=section_data.get("title", ""), content="" )) outline = ReportOutline( title=response.get("title", "模拟分析报告"), summary=response.get("summary", ""), sections=sections ) if progress_callback: progress_callback("planning", 100, "大纲规划完成") logger.info(f"大纲规划完成: {len(sections)} 个章节") return outline except Exception as e: logger.error(f"大纲规划失败: {str(e)}") # 返回默认大纲(3个章节,作为fallback) return ReportOutline( title="未来预测报告", summary="基于模拟预测的未来趋势与风险分析", sections=[ ReportSection(title="预测场景与核心发现"), ReportSection(title="人群行为预测分析"), ReportSection(title="趋势展望与风险提示") ] ) def _generate_section_react( self, section: ReportSection, outline: ReportOutline, previous_sections: List[str], progress_callback: Optional[Callable] = None, section_index: int = 0 ) -> str: """ 使用ReACT模式生成单个章节内容 ReACT循环: 1. Thought(思考)- 分析需要什么信息 2. Action(行动)- 调用工具获取信息 3. Observation(观察)- 分析工具返回结果 4. 重复直到信息足够或达到最大次数 5. 
Final Answer(最终回答)- 生成章节内容 Args: section: 要生成的章节 outline: 完整大纲 previous_sections: 之前章节的内容(用于保持连贯性) progress_callback: 进度回调 section_index: 章节索引(用于日志记录) Returns: 章节内容(Markdown格式) """ logger.info(f"ReACT生成章节: {section.title}") # 记录章节开始日志 if self.report_logger: self.report_logger.log_section_start(section.title, section_index) system_prompt = SECTION_SYSTEM_PROMPT_TEMPLATE.format( report_title=outline.title, report_summary=outline.summary, simulation_requirement=self.simulation_requirement, section_title=section.title, tools_description=self._get_tools_description(), ) # 构建用户prompt - 每个已完成章节各传入最大4000字 if previous_sections: previous_parts = [] for sec in previous_sections: # 每个章节最多4000字 truncated = sec[:4000] + "..." if len(sec) > 4000 else sec previous_parts.append(truncated) previous_content = "\n\n---\n\n".join(previous_parts) else: previous_content = "(这是第一个章节)" user_prompt = SECTION_USER_PROMPT_TEMPLATE.format( previous_content=previous_content, section_title=section.title, ) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ] # ReACT循环 tool_calls_count = 0 max_iterations = 5 # 最大迭代轮数 min_tool_calls = 3 # 最少工具调用次数 conflict_retries = 0 # 工具调用与Final Answer同时出现的连续冲突次数 used_tools = set() # 记录已调用过的工具名 all_tools = {"insight_forge", "panorama_search", "quick_search", "interview_agents"} # 报告上下文,用于InsightForge的子问题生成 report_context = f"章节标题: {section.title}\n模拟需求: {self.simulation_requirement}" for iteration in range(max_iterations): if progress_callback: progress_callback( "generating", int((iteration / max_iterations) * 100), f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})" ) # 调用LLM response = self.llm.chat( messages=messages, temperature=0.5, max_tokens=4096 ) # 检查 LLM 返回是否为 None(API 异常或内容为空) if response is None: logger.warning(f"章节 {section.title} 第 {iteration + 1} 次迭代: LLM 返回 None") # 如果还有迭代次数,添加消息并重试 if iteration < max_iterations - 1: messages.append({"role": "assistant", "content": "(响应为空)"}) 
messages.append({"role": "user", "content": "请继续生成内容。"}) continue # 最后一次迭代也返回 None,跳出循环进入强制收尾 break logger.debug(f"LLM响应: {response[:200]}...") # 解析一次,复用结果 tool_calls = self._parse_tool_calls(response) has_tool_calls = bool(tool_calls) has_final_answer = "Final Answer:" in response # ── 冲突处理:LLM 同时输出了工具调用和 Final Answer ── if has_tool_calls and has_final_answer: conflict_retries += 1 logger.warning( f"章节 {section.title} 第 {iteration+1} 轮: " f"LLM 同时输出工具调用和 Final Answer(第 {conflict_retries} 次冲突)" ) if conflict_retries <= 2: # 前两次:丢弃本次响应,要求 LLM 重新回复 messages.append({"role": "assistant", "content": response}) messages.append({ "role": "user", "content": ( "【格式错误】你在一次回复中同时包含了工具调用和 Final Answer,这是不允许的。\n" "每次回复只能做以下两件事之一:\n" "- 调用一个工具(输出一个 块,不要写 Final Answer)\n" "- 输出最终内容(以 'Final Answer:' 开头,不要包含 )\n" "请重新回复,只做其中一件事。" ), }) continue else: # 第三次:降级处理,截断到第一个工具调用,强制执行 logger.warning( f"章节 {section.title}: 连续 {conflict_retries} 次冲突," "降级为截断执行第一个工具调用" ) first_tool_end = response.find('') if first_tool_end != -1: response = response[:first_tool_end + len('')] tool_calls = self._parse_tool_calls(response) has_tool_calls = bool(tool_calls) has_final_answer = False conflict_retries = 0 # 记录 LLM 响应日志 if self.report_logger: self.report_logger.log_llm_response( section_title=section.title, section_index=section_index, response=response, iteration=iteration + 1, has_tool_calls=has_tool_calls, has_final_answer=has_final_answer ) # ── 情况1:LLM 输出了 Final Answer ── if has_final_answer: # 工具调用次数不足,拒绝并要求继续调工具 if tool_calls_count < min_tool_calls: messages.append({"role": "assistant", "content": response}) unused_tools = all_tools - used_tools unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else "" messages.append({ "role": "user", "content": REACT_INSUFFICIENT_TOOLS_MSG.format( tool_calls_count=tool_calls_count, min_tool_calls=min_tool_calls, unused_hint=unused_hint, ), }) continue # 正常结束 final_answer = response.split("Final Answer:")[-1].strip() 
logger.info(f"章节 {section.title} 生成完成(工具调用: {tool_calls_count}次)") if self.report_logger: self.report_logger.log_section_content( section_title=section.title, section_index=section_index, content=final_answer, tool_calls_count=tool_calls_count ) return final_answer # ── 情况2:LLM 尝试调用工具 ── if has_tool_calls: # 工具额度已耗尽 → 明确告知,要求输出 Final Answer if tool_calls_count >= self.MAX_TOOL_CALLS_PER_SECTION: messages.append({"role": "assistant", "content": response}) messages.append({ "role": "user", "content": REACT_TOOL_LIMIT_MSG.format( tool_calls_count=tool_calls_count, max_tool_calls=self.MAX_TOOL_CALLS_PER_SECTION, ), }) continue # 只执行第一个工具调用 call = tool_calls[0] if len(tool_calls) > 1: logger.info(f"LLM 尝试调用 {len(tool_calls)} 个工具,只执行第一个: {call['name']}") if self.report_logger: self.report_logger.log_tool_call( section_title=section.title, section_index=section_index, tool_name=call["name"], parameters=call.get("parameters", {}), iteration=iteration + 1 ) result = self._execute_tool( call["name"], call.get("parameters", {}), report_context=report_context ) if self.report_logger: self.report_logger.log_tool_result( section_title=section.title, section_index=section_index, tool_name=call["name"], result=result, iteration=iteration + 1 ) tool_calls_count += 1 used_tools.add(call['name']) # 构建未使用工具提示 unused_tools = all_tools - used_tools unused_hint = "" if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION: unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools)) messages.append({"role": "assistant", "content": response}) messages.append({ "role": "user", "content": REACT_OBSERVATION_TEMPLATE.format( tool_name=call["name"], result=result, tool_calls_count=tool_calls_count, max_tool_calls=self.MAX_TOOL_CALLS_PER_SECTION, used_tools_str=", ".join(used_tools), unused_hint=unused_hint, ), }) continue # ── 情况3:既没有工具调用,也没有 Final Answer ── messages.append({"role": "assistant", "content": response}) if tool_calls_count < min_tool_calls: # 
工具调用次数不足,推荐未用过的工具 unused_tools = all_tools - used_tools unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else "" messages.append({ "role": "user", "content": REACT_INSUFFICIENT_TOOLS_MSG_ALT.format( tool_calls_count=tool_calls_count, min_tool_calls=min_tool_calls, unused_hint=unused_hint, ), }) continue # 工具调用已足够,LLM 输出了内容但没带 "Final Answer:" 前缀 # 直接将这段内容作为最终答案,不再空转 logger.info(f"章节 {section.title} 未检测到 'Final Answer:' 前缀,直接采纳LLM输出作为最终内容(工具调用: {tool_calls_count}次)") final_answer = response.strip() if self.report_logger: self.report_logger.log_section_content( section_title=section.title, section_index=section_index, content=final_answer, tool_calls_count=tool_calls_count ) return final_answer # 达到最大迭代次数,强制生成内容 logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成") messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG}) response = self.llm.chat( messages=messages, temperature=0.5, max_tokens=4096 ) # 检查强制收尾时 LLM 返回是否为 None if response is None: logger.error(f"章节 {section.title} 强制收尾时 LLM 返回 None,使用默认错误提示") final_answer = f"(本章节生成失败:LLM 返回空响应,请稍后重试)" elif "Final Answer:" in response: final_answer = response.split("Final Answer:")[-1].strip() else: final_answer = response # 记录章节内容生成完成日志 if self.report_logger: self.report_logger.log_section_content( section_title=section.title, section_index=section_index, content=final_answer, tool_calls_count=tool_calls_count ) return final_answer def generate_report( self, progress_callback: Optional[Callable[[str, int, str], None]] = None, report_id: Optional[str] = None ) -> Report: """ 生成完整报告(分章节实时输出) 每个章节生成完成后立即保存到文件夹,不需要等待整个报告完成。 文件结构: reports/{report_id}/ meta.json - 报告元信息 outline.json - 报告大纲 progress.json - 生成进度 section_01.md - 第1章节 section_02.md - 第2章节 ... 
full_report.md - 完整报告 Args: progress_callback: 进度回调函数 (stage, progress, message) report_id: 报告ID(可选,如果不传则自动生成) Returns: Report: 完整报告 """ import uuid # 如果没有传入 report_id,则自动生成 if not report_id: report_id = f"report_{uuid.uuid4().hex[:12]}" start_time = datetime.now() report = Report( report_id=report_id, simulation_id=self.simulation_id, graph_id=self.graph_id, simulation_requirement=self.simulation_requirement, status=ReportStatus.PENDING, created_at=datetime.now().isoformat() ) # 已完成的章节标题列表(用于进度追踪) completed_section_titles = [] try: # 初始化:创建报告文件夹并保存初始状态 ReportManager._ensure_report_folder(report_id) # 初始化日志记录器(结构化日志 agent_log.jsonl) self.report_logger = ReportLogger(report_id) self.report_logger.log_start( simulation_id=self.simulation_id, graph_id=self.graph_id, simulation_requirement=self.simulation_requirement ) # 初始化控制台日志记录器(console_log.txt) self.console_logger = ReportConsoleLogger(report_id) ReportManager.update_progress( report_id, "pending", 0, "初始化报告...", completed_sections=[] ) ReportManager.save_report(report) # 阶段1: 规划大纲 report.status = ReportStatus.PLANNING ReportManager.update_progress( report_id, "planning", 5, "开始规划报告大纲...", completed_sections=[] ) # 记录规划开始日志 self.report_logger.log_planning_start() if progress_callback: progress_callback("planning", 0, "开始规划报告大纲...") outline = self.plan_outline( progress_callback=lambda stage, prog, msg: progress_callback(stage, prog // 5, msg) if progress_callback else None ) report.outline = outline # 记录规划完成日志 self.report_logger.log_planning_complete(outline.to_dict()) # 保存大纲到文件 ReportManager.save_outline(report_id, outline) ReportManager.update_progress( report_id, "planning", 15, f"大纲规划完成,共{len(outline.sections)}个章节", completed_sections=[] ) ReportManager.save_report(report) logger.info(f"大纲已保存到文件: {report_id}/outline.json") # 阶段2: 逐章节生成(分章节保存) report.status = ReportStatus.GENERATING total_sections = len(outline.sections) generated_sections = [] # 保存内容用于上下文 for i, section in enumerate(outline.sections): 
section_num = i + 1 base_progress = 20 + int((i / total_sections) * 70) # 更新进度 ReportManager.update_progress( report_id, "generating", base_progress, f"正在生成章节: {section.title} ({section_num}/{total_sections})", current_section=section.title, completed_sections=completed_section_titles ) if progress_callback: progress_callback( "generating", base_progress, f"正在生成章节: {section.title} ({section_num}/{total_sections})" ) # 生成主章节内容 section_content = self._generate_section_react( section=section, outline=outline, previous_sections=generated_sections, progress_callback=lambda stage, prog, msg: progress_callback( stage, base_progress + int(prog * 0.7 / total_sections), msg ) if progress_callback else None, section_index=section_num ) section.content = section_content generated_sections.append(f"## {section.title}\n\n{section_content}") # 保存章节 ReportManager.save_section(report_id, section_num, section) completed_section_titles.append(section.title) # 记录章节完成日志 full_section_content = f"## {section.title}\n\n{section_content}" if self.report_logger: self.report_logger.log_section_full_complete( section_title=section.title, section_index=section_num, full_content=full_section_content.strip() ) logger.info(f"章节已保存: {report_id}/section_{section_num:02d}.md") # 更新进度 ReportManager.update_progress( report_id, "generating", base_progress + int(70 / total_sections), f"章节 {section.title} 已完成", current_section=None, completed_sections=completed_section_titles ) # 阶段3: 组装完整报告 if progress_callback: progress_callback("generating", 95, "正在组装完整报告...") ReportManager.update_progress( report_id, "generating", 95, "正在组装完整报告...", completed_sections=completed_section_titles ) # 使用ReportManager组装完整报告 report.markdown_content = ReportManager.assemble_full_report(report_id, outline) report.status = ReportStatus.COMPLETED report.completed_at = datetime.now().isoformat() # 计算总耗时 total_time_seconds = (datetime.now() - start_time).total_seconds() # 记录报告完成日志 if self.report_logger: 
self.report_logger.log_report_complete( total_sections=total_sections, total_time_seconds=total_time_seconds ) # 保存最终报告 ReportManager.save_report(report) ReportManager.update_progress( report_id, "completed", 100, "报告生成完成", completed_sections=completed_section_titles ) if progress_callback: progress_callback("completed", 100, "报告生成完成") logger.info(f"报告生成完成: {report_id}") # 关闭控制台日志记录器 if self.console_logger: self.console_logger.close() self.console_logger = None return report except Exception as e: logger.error(f"报告生成失败: {str(e)}") report.status = ReportStatus.FAILED report.error = str(e) # 记录错误日志 if self.report_logger: self.report_logger.log_error(str(e), "failed") # 保存失败状态 try: ReportManager.save_report(report) ReportManager.update_progress( report_id, "failed", -1, f"报告生成失败: {str(e)}", completed_sections=completed_section_titles ) except Exception: pass # 忽略保存失败的错误 # 关闭控制台日志记录器 if self.console_logger: self.console_logger.close() self.console_logger = None return report def chat( self, message: str, chat_history: List[Dict[str, str]] = None ) -> Dict[str, Any]: """ 与Report Agent对话 在对话中Agent可以自主调用检索工具来回答问题 Args: message: 用户消息 chat_history: 对话历史 Returns: { "response": "Agent回复", "tool_calls": [调用的工具列表], "sources": [信息来源] } """ logger.info(f"Report Agent对话: {message[:50]}...") chat_history = chat_history or [] # 获取已生成的报告内容 report_content = "" try: report = ReportManager.get_report_by_simulation(self.simulation_id) if report and report.markdown_content: # 限制报告长度,避免上下文过长 report_content = report.markdown_content[:15000] if len(report.markdown_content) > 15000: report_content += "\n\n... [报告内容已截断] ..." 
except Exception as e: logger.warning(f"获取报告内容失败: {e}") system_prompt = CHAT_SYSTEM_PROMPT_TEMPLATE.format( simulation_requirement=self.simulation_requirement, report_content=report_content if report_content else "(暂无报告)", tools_description=self._get_tools_description(), ) # 构建消息 messages = [{"role": "system", "content": system_prompt}] # 添加历史对话 for h in chat_history[-10:]: # 限制历史长度 messages.append(h) # 添加用户消息 messages.append({ "role": "user", "content": message }) # ReACT循环(简化版) tool_calls_made = [] max_iterations = 2 # 减少迭代轮数 for iteration in range(max_iterations): response = self.llm.chat( messages=messages, temperature=0.5 ) # 解析工具调用 tool_calls = self._parse_tool_calls(response) if not tool_calls: # 没有工具调用,直接返回响应 clean_response = re.sub(r'.*?', '', response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) return { "response": clean_response.strip(), "tool_calls": tool_calls_made, "sources": [tc.get("parameters", {}).get("query", "") for tc in tool_calls_made] } # 执行工具调用(限制数量) tool_results = [] for call in tool_calls[:1]: # 每轮最多执行1次工具调用 if len(tool_calls_made) >= self.MAX_TOOL_CALLS_PER_CHAT: break result = self._execute_tool(call["name"], call.get("parameters", {})) tool_results.append({ "tool": call["name"], "result": result[:1500] # 限制结果长度 }) tool_calls_made.append(call) # 将结果添加到消息 messages.append({"role": "assistant", "content": response}) observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results]) messages.append({ "role": "user", "content": observation + CHAT_OBSERVATION_SUFFIX }) # 达到最大迭代,获取最终响应 final_response = self.llm.chat( messages=messages, temperature=0.5 ) # 清理响应 clean_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) return { "response": clean_response.strip(), "tool_calls": tool_calls_made, "sources": [tc.get("parameters", {}).get("query", "") for tc in tool_calls_made] } class ReportManager: """ 报告管理器 
负责报告的持久化存储和检索 文件结构(分章节输出): reports/ {report_id}/ meta.json - 报告元信息和状态 outline.json - 报告大纲 progress.json - 生成进度 section_01.md - 第1章节 section_02.md - 第2章节 ... full_report.md - 完整报告 """ # 报告存储目录 REPORTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'reports') @classmethod def _ensure_reports_dir(cls): """确保报告根目录存在""" os.makedirs(cls.REPORTS_DIR, exist_ok=True) @classmethod def _get_report_folder(cls, report_id: str) -> str: """获取报告文件夹路径""" return os.path.join(cls.REPORTS_DIR, report_id) @classmethod def _ensure_report_folder(cls, report_id: str) -> str: """确保报告文件夹存在并返回路径""" folder = cls._get_report_folder(report_id) os.makedirs(folder, exist_ok=True) return folder @classmethod def _get_report_path(cls, report_id: str) -> str: """获取报告元信息文件路径""" return os.path.join(cls._get_report_folder(report_id), "meta.json") @classmethod def _get_report_markdown_path(cls, report_id: str) -> str: """获取完整报告Markdown文件路径""" return os.path.join(cls._get_report_folder(report_id), "full_report.md") @classmethod def _get_outline_path(cls, report_id: str) -> str: """获取大纲文件路径""" return os.path.join(cls._get_report_folder(report_id), "outline.json") @classmethod def _get_progress_path(cls, report_id: str) -> str: """获取进度文件路径""" return os.path.join(cls._get_report_folder(report_id), "progress.json") @classmethod def _get_section_path(cls, report_id: str, section_index: int) -> str: """获取章节Markdown文件路径""" return os.path.join(cls._get_report_folder(report_id), f"section_{section_index:02d}.md") @classmethod def _get_agent_log_path(cls, report_id: str) -> str: """获取 Agent 日志文件路径""" return os.path.join(cls._get_report_folder(report_id), "agent_log.jsonl") @classmethod def _get_console_log_path(cls, report_id: str) -> str: """获取控制台日志文件路径""" return os.path.join(cls._get_report_folder(report_id), "console_log.txt") @classmethod def get_console_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: """ 获取控制台日志内容 这是报告生成过程中的控制台输出日志(INFO、WARNING等), 与 agent_log.jsonl 的结构化日志不同。 Args: report_id: 报告ID 
from_line: 从第几行开始读取(用于增量获取,0 表示从头开始) Returns: { "logs": [日志行列表], "total_lines": 总行数, "from_line": 起始行号, "has_more": 是否还有更多日志 } """ log_path = cls._get_console_log_path(report_id) if not os.path.exists(log_path): return { "logs": [], "total_lines": 0, "from_line": 0, "has_more": False } logs = [] total_lines = 0 with open(log_path, 'r', encoding='utf-8') as f: for i, line in enumerate(f): total_lines = i + 1 if i >= from_line: # 保留原始日志行,去掉末尾换行符 logs.append(line.rstrip('\n\r')) return { "logs": logs, "total_lines": total_lines, "from_line": from_line, "has_more": False # 已读取到末尾 } @classmethod def get_console_log_stream(cls, report_id: str) -> List[str]: """ 获取完整的控制台日志(一次性获取全部) Args: report_id: 报告ID Returns: 日志行列表 """ result = cls.get_console_log(report_id, from_line=0) return result["logs"] @classmethod def get_agent_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: """ 获取 Agent 日志内容 Args: report_id: 报告ID from_line: 从第几行开始读取(用于增量获取,0 表示从头开始) Returns: { "logs": [日志条目列表], "total_lines": 总行数, "from_line": 起始行号, "has_more": 是否还有更多日志 } """ log_path = cls._get_agent_log_path(report_id) if not os.path.exists(log_path): return { "logs": [], "total_lines": 0, "from_line": 0, "has_more": False } logs = [] total_lines = 0 with open(log_path, 'r', encoding='utf-8') as f: for i, line in enumerate(f): total_lines = i + 1 if i >= from_line: try: log_entry = json.loads(line.strip()) logs.append(log_entry) except json.JSONDecodeError: # 跳过解析失败的行 continue return { "logs": logs, "total_lines": total_lines, "from_line": from_line, "has_more": False # 已读取到末尾 } @classmethod def get_agent_log_stream(cls, report_id: str) -> List[Dict[str, Any]]: """ 获取完整的 Agent 日志(用于一次性获取全部) Args: report_id: 报告ID Returns: 日志条目列表 """ result = cls.get_agent_log(report_id, from_line=0) return result["logs"] @classmethod def save_outline(cls, report_id: str, outline: ReportOutline) -> None: """ 保存报告大纲 在规划阶段完成后立即调用 """ cls._ensure_report_folder(report_id) with open(cls._get_outline_path(report_id), 
'w', encoding='utf-8') as f: json.dump(outline.to_dict(), f, ensure_ascii=False, indent=2) logger.info(f"大纲已保存: {report_id}") @classmethod def save_section( cls, report_id: str, section_index: int, section: ReportSection ) -> str: """ 保存单个章节 在每个章节生成完成后立即调用,实现分章节输出 Args: report_id: 报告ID section_index: 章节索引(从1开始) section: 章节对象 Returns: 保存的文件路径 """ cls._ensure_report_folder(report_id) # 构建章节Markdown内容 - 清理可能存在的重复标题 cleaned_content = cls._clean_section_content(section.content, section.title) md_content = f"## {section.title}\n\n" if cleaned_content: md_content += f"{cleaned_content}\n\n" # 保存文件 file_suffix = f"section_{section_index:02d}.md" file_path = os.path.join(cls._get_report_folder(report_id), file_suffix) with open(file_path, 'w', encoding='utf-8') as f: f.write(md_content) logger.info(f"章节已保存: {report_id}/{file_suffix}") return file_path @classmethod def _clean_section_content(cls, content: str, section_title: str) -> str: """ 清理章节内容 1. 移除内容开头与章节标题重复的Markdown标题行 2. 将所有 ### 及以下级别的标题转换为粗体文本 Args: content: 原始内容 section_title: 章节标题 Returns: 清理后的内容 """ import re if not content: return content content = content.strip() lines = content.split('\n') cleaned_lines = [] skip_next_empty = False for i, line in enumerate(lines): stripped = line.strip() # 检查是否是Markdown标题行 heading_match = re.match(r'^(#{1,6})\s+(.+)$', stripped) if heading_match: level = len(heading_match.group(1)) title_text = heading_match.group(2).strip() # 检查是否是与章节标题重复的标题(跳过前5行内的重复) if i < 5: if title_text == section_title or title_text.replace(' ', '') == section_title.replace(' ', ''): skip_next_empty = True continue # 将所有级别的标题(#, ##, ###, ####等)转换为粗体 # 因为章节标题由系统添加,内容中不应有任何标题 cleaned_lines.append(f"**{title_text}**") cleaned_lines.append("") # 添加空行 continue # 如果上一行是被跳过的标题,且当前行为空,也跳过 if skip_next_empty and stripped == '': skip_next_empty = False continue skip_next_empty = False cleaned_lines.append(line) # 移除开头的空行 while cleaned_lines and cleaned_lines[0].strip() == '': cleaned_lines.pop(0) # 移除开头的分隔线 
while cleaned_lines and cleaned_lines[0].strip() in ['---', '***', '___']: cleaned_lines.pop(0) # 同时移除分隔线后的空行 while cleaned_lines and cleaned_lines[0].strip() == '': cleaned_lines.pop(0) return '\n'.join(cleaned_lines) @classmethod def update_progress( cls, report_id: str, status: str, progress: int, message: str, current_section: str = None, completed_sections: List[str] = None ) -> None: """ 更新报告生成进度 前端可以通过读取progress.json获取实时进度 """ cls._ensure_report_folder(report_id) progress_data = { "status": status, "progress": progress, "message": message, "current_section": current_section, "completed_sections": completed_sections or [], "updated_at": datetime.now().isoformat() } with open(cls._get_progress_path(report_id), 'w', encoding='utf-8') as f: json.dump(progress_data, f, ensure_ascii=False, indent=2) @classmethod def get_progress(cls, report_id: str) -> Optional[Dict[str, Any]]: """获取报告生成进度""" path = cls._get_progress_path(report_id) if not os.path.exists(path): return None with open(path, 'r', encoding='utf-8') as f: return json.load(f) @classmethod def get_generated_sections(cls, report_id: str) -> List[Dict[str, Any]]: """ 获取已生成的章节列表 返回所有已保存的章节文件信息 """ folder = cls._get_report_folder(report_id) if not os.path.exists(folder): return [] sections = [] for filename in sorted(os.listdir(folder)): if filename.startswith('section_') and filename.endswith('.md'): file_path = os.path.join(folder, filename) with open(file_path, 'r', encoding='utf-8') as f: content = f.read() # 从文件名解析章节索引 parts = filename.replace('.md', '').split('_') section_index = int(parts[1]) sections.append({ "filename": filename, "section_index": section_index, "content": content }) return sections @classmethod def assemble_full_report(cls, report_id: str, outline: ReportOutline) -> str: """ 组装完整报告 从已保存的章节文件组装完整报告,并进行标题清理 """ folder = cls._get_report_folder(report_id) # 构建报告头部 md_content = f"# {outline.title}\n\n" md_content += f"> {outline.summary}\n\n" md_content += f"---\n\n" # 按顺序读取所有章节文件 
sections = cls.get_generated_sections(report_id) for section_info in sections: md_content += section_info["content"] # 后处理:清理整个报告的标题问题 md_content = cls._post_process_report(md_content, outline) # 保存完整报告 full_path = cls._get_report_markdown_path(report_id) with open(full_path, 'w', encoding='utf-8') as f: f.write(md_content) logger.info(f"完整报告已组装: {report_id}") return md_content @classmethod def _post_process_report(cls, content: str, outline: ReportOutline) -> str: """ 后处理报告内容 1. 移除重复的标题 2. 保留报告主标题(#)和章节标题(##),移除其他级别的标题(###, ####等) 3. 清理多余的空行和分隔线 Args: content: 原始报告内容 outline: 报告大纲 Returns: 处理后的内容 """ import re lines = content.split('\n') processed_lines = [] prev_was_heading = False # 收集大纲中的所有章节标题 section_titles = set() for section in outline.sections: section_titles.add(section.title) i = 0 while i < len(lines): line = lines[i] stripped = line.strip() # 检查是否是标题行 heading_match = re.match(r'^(#{1,6})\s+(.+)$', stripped) if heading_match: level = len(heading_match.group(1)) title = heading_match.group(2).strip() # 检查是否是重复标题(在连续5行内出现相同内容的标题) is_duplicate = False for j in range(max(0, len(processed_lines) - 5), len(processed_lines)): prev_line = processed_lines[j].strip() prev_match = re.match(r'^(#{1,6})\s+(.+)$', prev_line) if prev_match: prev_title = prev_match.group(2).strip() if prev_title == title: is_duplicate = True break if is_duplicate: # 跳过重复标题及其后的空行 i += 1 while i < len(lines) and lines[i].strip() == '': i += 1 continue # 标题层级处理: # - # (level=1) 只保留报告主标题 # - ## (level=2) 保留章节标题 # - ### 及以下 (level>=3) 转换为粗体文本 if level == 1: if title == outline.title: # 保留报告主标题 processed_lines.append(line) prev_was_heading = True elif title in section_titles: # 章节标题错误使用了#,修正为## processed_lines.append(f"## {title}") prev_was_heading = True else: # 其他一级标题转为粗体 processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False elif level == 2: if title in section_titles or title == outline.title: # 保留章节标题 processed_lines.append(line) prev_was_heading = 
True else: # 非章节的二级标题转为粗体 processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False else: # ### 及以下级别的标题转换为粗体文本 processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False i += 1 continue elif stripped == '---' and prev_was_heading: # 跳过标题后紧跟的分隔线 i += 1 continue elif stripped == '' and prev_was_heading: # 标题后只保留一个空行 if processed_lines and processed_lines[-1].strip() != '': processed_lines.append(line) prev_was_heading = False else: processed_lines.append(line) prev_was_heading = False i += 1 # 清理连续的多个空行(保留最多2个) result_lines = [] empty_count = 0 for line in processed_lines: if line.strip() == '': empty_count += 1 if empty_count <= 2: result_lines.append(line) else: empty_count = 0 result_lines.append(line) return '\n'.join(result_lines) @classmethod def save_report(cls, report: Report) -> None: """保存报告元信息和完整报告""" cls._ensure_report_folder(report.report_id) # 保存元信息JSON with open(cls._get_report_path(report.report_id), 'w', encoding='utf-8') as f: json.dump(report.to_dict(), f, ensure_ascii=False, indent=2) # 保存大纲 if report.outline: cls.save_outline(report.report_id, report.outline) # 保存完整Markdown报告 if report.markdown_content: with open(cls._get_report_markdown_path(report.report_id), 'w', encoding='utf-8') as f: f.write(report.markdown_content) logger.info(f"报告已保存: {report.report_id}") @classmethod def get_report(cls, report_id: str) -> Optional[Report]: """获取报告""" path = cls._get_report_path(report_id) if not os.path.exists(path): # 兼容旧格式:检查直接存储在reports目录下的文件 old_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.json") if os.path.exists(old_path): path = old_path else: return None with open(path, 'r', encoding='utf-8') as f: data = json.load(f) # 重建Report对象 outline = None if data.get('outline'): outline_data = data['outline'] sections = [] for s in outline_data.get('sections', []): sections.append(ReportSection( title=s['title'], content=s.get('content', '') )) outline = ReportOutline( 
title=outline_data['title'], summary=outline_data['summary'], sections=sections ) # 如果markdown_content为空,尝试从full_report.md读取 markdown_content = data.get('markdown_content', '') if not markdown_content: full_report_path = cls._get_report_markdown_path(report_id) if os.path.exists(full_report_path): with open(full_report_path, 'r', encoding='utf-8') as f: markdown_content = f.read() return Report( report_id=data['report_id'], simulation_id=data['simulation_id'], graph_id=data['graph_id'], simulation_requirement=data['simulation_requirement'], status=ReportStatus(data['status']), outline=outline, markdown_content=markdown_content, created_at=data.get('created_at', ''), completed_at=data.get('completed_at', ''), error=data.get('error') ) @classmethod def get_report_by_simulation(cls, simulation_id: str) -> Optional[Report]: """根据模拟ID获取报告""" cls._ensure_reports_dir() for item in os.listdir(cls.REPORTS_DIR): item_path = os.path.join(cls.REPORTS_DIR, item) # 新格式:文件夹 if os.path.isdir(item_path): report = cls.get_report(item) if report and report.simulation_id == simulation_id: return report # 兼容旧格式:JSON文件 elif item.endswith('.json'): report_id = item[:-5] report = cls.get_report(report_id) if report and report.simulation_id == simulation_id: return report return None @classmethod def list_reports(cls, simulation_id: Optional[str] = None, limit: int = 50) -> List[Report]: """列出报告""" cls._ensure_reports_dir() reports = [] for item in os.listdir(cls.REPORTS_DIR): item_path = os.path.join(cls.REPORTS_DIR, item) # 新格式:文件夹 if os.path.isdir(item_path): report = cls.get_report(item) if report: if simulation_id is None or report.simulation_id == simulation_id: reports.append(report) # 兼容旧格式:JSON文件 elif item.endswith('.json'): report_id = item[:-5] report = cls.get_report(report_id) if report: if simulation_id is None or report.simulation_id == simulation_id: reports.append(report) # 按创建时间倒序 reports.sort(key=lambda r: r.created_at, reverse=True) return reports[:limit] @classmethod 
def delete_report(cls, report_id: str) -> bool:
    """Delete a report from disk.

    The current layout keeps each report in its own folder, which is removed
    recursively. When no such folder exists, the legacy ``<report_id>.json``
    and ``<report_id>.md`` files directly under ``cls.REPORTS_DIR`` are
    removed instead.

    Args:
        report_id: Identifier of the report to delete.

    Returns:
        True when a folder or at least one legacy file was removed,
        otherwise False.
    """
    import shutil

    report_folder = cls._get_report_folder(report_id)

    # Current layout: one folder per report (isdir is False for missing paths).
    if os.path.isdir(report_folder):
        shutil.rmtree(report_folder)
        logger.info(f"报告文件夹已删除: {report_id}")
        return True

    # Legacy layout: loose files next to each other in the reports directory.
    removed_any = False
    for extension in ("json", "md"):
        legacy_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.{extension}")
        if os.path.exists(legacy_path):
            os.remove(legacy_path)
            removed_any = True
    return removed_any
@dataclass
class TimeSimulationConfig:
    """Time-related simulation settings (modelled on a Chinese daily routine)."""

    # Total simulated duration in hours; the default of 72 is three days.
    total_simulation_hours: int = 72

    # Simulated minutes represented by one round (60 = one hour per round).
    minutes_per_round: int = 60

    # Range of agents activated per simulated hour.
    agents_per_hour_min: int = 5
    agents_per_hour_max: int = 20

    # Evening peak (19:00-22:59), the most active period.
    peak_hours: List[int] = field(default_factory=lambda: list(range(19, 23)))
    peak_activity_multiplier: float = 1.5

    # Overnight trough (00:00-05:59) with almost no activity.
    off_peak_hours: List[int] = field(default_factory=lambda: list(range(0, 6)))
    off_peak_activity_multiplier: float = 0.05

    # Early morning (06:00-08:59).
    morning_hours: List[int] = field(default_factory=lambda: list(range(6, 9)))
    morning_activity_multiplier: float = 0.4

    # Working hours (09:00-18:59).
    work_hours: List[int] = field(default_factory=lambda: list(range(9, 19)))
    work_activity_multiplier: float = 0.7
def to_dict(self) -> Dict[str, Any]:
    """Return the parameters as a plain dictionary.

    Nested dataclass configs are expanded with ``dataclasses.asdict``; a
    platform config that is unset (falsy) is emitted as ``None``.
    """
    def optional_config(config):
        # Platform configs are optional; keep the key but store None.
        return asdict(config) if config else None

    payload: Dict[str, Any] = {}
    for name in ("simulation_id", "project_id", "graph_id", "simulation_requirement"):
        payload[name] = getattr(self, name)

    payload["time_config"] = asdict(self.time_config)
    payload["agent_configs"] = list(map(asdict, self.agent_configs))
    payload["event_config"] = asdict(self.event_config)
    payload["twitter_config"] = optional_config(self.twitter_config)
    payload["reddit_config"] = optional_config(self.reddit_config)

    for name in ("llm_model", "llm_base_url", "generated_at", "generation_reasoning"):
        payload[name] = getattr(self, name)
    return payload
def generate_config(
    self,
    simulation_id: str,
    project_id: str,
    graph_id: str,
    simulation_requirement: str,
    document_text: str,
    entities: List[EntityNode],
    enable_twitter: bool = True,
    enable_reddit: bool = True,
    progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> SimulationParameters:
    """Generate the complete simulation configuration step by step.

    Steps: time config, event config, agent configs in batches of
    ``AGENTS_PER_BATCH``, poster assignment for the initial posts, and
    finally the fixed per-platform configs.

    Args:
        simulation_id: Simulation ID.
        project_id: Project ID.
        graph_id: Graph ID.
        simulation_requirement: Description of the simulation requirement.
        document_text: Original document content.
        entities: Filtered entity list; one agent config is built per entity.
        enable_twitter: Whether to emit a Twitter platform config.
        enable_reddit: Whether to emit a Reddit platform config.
        progress_callback: Optional callable invoked as
            ``(current_step, total_steps, message)``.

    Returns:
        SimulationParameters: the assembled simulation parameters.
    """
    entity_total = len(entities)
    logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}, 实体数={entity_total}")

    # Steps = time config + event config + N agent batches + platform config.
    batch_size = self.AGENTS_PER_BATCH
    num_batches = math.ceil(entity_total / batch_size)
    total_steps = 3 + num_batches

    def report_progress(step: int, message: str):
        if progress_callback:
            progress_callback(step, total_steps, message)
        logger.info(f"[{step}/{total_steps}] {message}")

    # Shared context handed to every generation step.
    context = self._build_context(
        simulation_requirement=simulation_requirement,
        document_text=document_text,
        entities=entities
    )
    reasoning_parts = []

    # ---- Step 1: time config ----
    report_progress(1, "生成时间配置...")
    time_config_result = self._generate_time_config(context, entity_total)
    time_config = self._parse_time_config(time_config_result, entity_total)
    reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}")

    # ---- Step 2: event config ----
    report_progress(2, "生成事件配置和热点话题...")
    event_config_result = self._generate_event_config(context, simulation_requirement, entities)
    event_config = self._parse_event_config(event_config_result)
    reasoning_parts.append(f"事件配置: {event_config_result.get('reasoning', '成功')}")

    # ---- Steps 3..N: agent configs, one batch per step ----
    all_agent_configs = []
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, entity_total)
        report_progress(
            3 + batch_idx,
            f"生成Agent配置 ({start_idx + 1}-{end_idx}/{entity_total})..."
        )
        all_agent_configs.extend(
            self._generate_agent_configs_batch(
                context=context,
                entities=entities[start_idx:end_idx],
                start_idx=start_idx,
                simulation_requirement=simulation_requirement
            )
        )
    reasoning_parts.append(f"Agent配置: 成功生成 {len(all_agent_configs)} 个")

    # ---- Assign a poster agent to every initial post ----
    logger.info("为初始帖子分配合适的发布者 Agent...")
    event_config = self._assign_initial_post_agents(event_config, all_agent_configs)
    assigned_count = sum(
        1 for post in event_config.initial_posts
        if post.get("poster_agent_id") is not None
    )
    reasoning_parts.append(f"初始帖子分配: {assigned_count} 个帖子已分配发布者")

    # ---- Final step: platform configs (fixed presets) ----
    report_progress(total_steps, "生成平台配置...")
    twitter_config = PlatformConfig(
        platform="twitter",
        recency_weight=0.4,
        popularity_weight=0.3,
        relevance_weight=0.3,
        viral_threshold=10,
        echo_chamber_strength=0.5
    ) if enable_twitter else None
    reddit_config = PlatformConfig(
        platform="reddit",
        recency_weight=0.3,
        popularity_weight=0.4,
        relevance_weight=0.3,
        viral_threshold=15,
        echo_chamber_strength=0.6
    ) if enable_reddit else None

    params = SimulationParameters(
        simulation_id=simulation_id,
        project_id=project_id,
        graph_id=graph_id,
        simulation_requirement=simulation_requirement,
        time_config=time_config,
        agent_configs=all_agent_configs,
        event_config=event_config,
        twitter_config=twitter_config,
        reddit_config=reddit_config,
        llm_model=self.model_name,
        llm_base_url=self.base_url,
        generation_reasoning=" | ".join(reasoning_parts)
    )

    logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置")
    return params
500 # 留500字符余量 if remaining_length > 0 and document_text: doc_text = document_text[:remaining_length] if len(document_text) > remaining_length: doc_text += "\n...(文档已截断)" context_parts.append(f"\n## 原始文档内容\n{doc_text}") return "\n".join(context_parts) def _summarize_entities(self, entities: List[EntityNode]) -> str: """生成实体摘要""" lines = [] # 按类型分组 by_type: Dict[str, List[EntityNode]] = {} for e in entities: t = e.get_entity_type() or "Unknown" if t not in by_type: by_type[t] = [] by_type[t].append(e) for entity_type, type_entities in by_type.items(): lines.append(f"\n### {entity_type} ({len(type_entities)}个)") # 使用配置的显示数量和摘要长度 display_count = self.ENTITIES_PER_TYPE_DISPLAY summary_len = self.ENTITY_SUMMARY_LENGTH for e in type_entities[:display_count]: summary_preview = (e.summary[:summary_len] + "...") if len(e.summary) > summary_len else e.summary lines.append(f"- {e.name}: {summary_preview}") if len(type_entities) > display_count: lines.append(f" ... 还有 {len(type_entities) - display_count} 个") return "\n".join(lines) def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]: """带重试的LLM调用,包含JSON修复逻辑""" import re max_attempts = 3 last_error = None for attempt in range(max_attempts): try: response = self.client.chat.completions.create( model=self.model_name, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 # 不设置max_tokens,让LLM自由发挥 ) content = response.choices[0].message.content finish_reason = response.choices[0].finish_reason # 检查是否被截断 if finish_reason == 'length': logger.warning(f"LLM输出被截断 (attempt {attempt+1})") content = self._fix_truncated_json(content) # 尝试解析JSON try: return json.loads(content) except json.JSONDecodeError as e: logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}") # 尝试修复JSON fixed = self._try_fix_config_json(content) if fixed: return fixed last_error = e except Exception as e: 
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") last_error = e import time time.sleep(2 * (attempt + 1)) raise last_error or Exception("LLM调用失败") def _fix_truncated_json(self, content: str) -> str: """修复被截断的JSON""" content = content.strip() # 计算未闭合的括号 open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') # 检查是否有未闭合的字符串 if content and content[-1] not in '",}]': content += '"' # 闭合括号 content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]: """尝试修复配置JSON""" import re # 修复被截断的情况 content = self._fix_truncated_json(content) # 提取JSON部分 json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() # 移除字符串中的换行符 def fix_string(match): s = match.group(0) s = s.replace('\n', ' ').replace('\r', ' ') s = re.sub(r'\s+', ' ', s) return s json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string, json_str) try: return json.loads(json_str) except: # 尝试移除所有控制字符 json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) json_str = re.sub(r'\s+', ' ', json_str) try: return json.loads(json_str) except: pass return None def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]: """生成时间配置""" # 使用配置的上下文截断长度 context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH] # 计算最大允许值(80%的agent数) max_agents_allowed = max(1, int(num_entities * 0.9)) prompt = f"""基于以下模拟需求,生成时间模拟配置。 {context_truncated} ## 任务 请生成时间配置JSON。 ### 基本原则(仅供参考,需根据具体事件和参与群体灵活调整): - 用户群体为中国人,需符合北京时间作息习惯 - 凌晨0-5点几乎无人活动(活跃度系数0.05) - 早上6-8点逐渐活跃(活跃度系数0.4) - 工作时间9-18点中等活跃(活跃度系数0.7) - 晚间19-22点是高峰期(活跃度系数1.5) - 23点后活跃度下降(活跃度系数0.5) - 一般规律:凌晨低活跃、早间渐增、工作时段中等、晚间高峰 - **重要**:以下示例值仅供参考,你需要根据事件性质、参与群体特点来调整具体时段 - 例如:学生群体高峰可能是21-23点;媒体全天活跃;官方机构只在工作时间 - 例如:突发热点可能导致深夜也有讨论,off_peak_hours 可适当缩短 ### 返回JSON格式(不要markdown) 示例: {{ "total_simulation_hours": 72, "minutes_per_round": 60, "agents_per_hour_min": 5, "agents_per_hour_max": 50, 
"peak_hours": [19, 20, 21, 22], "off_peak_hours": [0, 1, 2, 3, 4, 5], "morning_hours": [6, 7, 8], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], "reasoning": "针对该事件的时间配置说明" }} 字段说明: - total_simulation_hours (int): 模拟总时长,24-168小时,突发事件短、持续话题长 - minutes_per_round (int): 每轮时长,30-120分钟,建议60分钟 - agents_per_hour_min (int): 每小时最少激活Agent数(取值范围: 1-{max_agents_allowed}) - agents_per_hour_max (int): 每小时最多激活Agent数(取值范围: 1-{max_agents_allowed}) - peak_hours (int数组): 高峰时段,根据事件参与群体调整 - off_peak_hours (int数组): 低谷时段,通常深夜凌晨 - morning_hours (int数组): 早间时段 - work_hours (int数组): 工作时段 - reasoning (string): 简要说明为什么这样配置""" system_prompt = "你是社交媒体模拟专家。返回纯JSON格式,时间配置需符合中国人作息习惯。" try: return self._call_llm_with_retry(prompt, system_prompt) except Exception as e: logger.warning(f"时间配置LLM生成失败: {e}, 使用默认配置") return self._get_default_time_config(num_entities) def _get_default_time_config(self, num_entities: int) -> Dict[str, Any]: """获取默认时间配置(中国人作息)""" return { "total_simulation_hours": 72, "minutes_per_round": 60, # 每轮1小时,加快时间流速 "agents_per_hour_min": max(1, num_entities // 15), "agents_per_hour_max": max(5, num_entities // 5), "peak_hours": [19, 20, 21, 22], "off_peak_hours": [0, 1, 2, 3, 4, 5], "morning_hours": [6, 7, 8], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], "reasoning": "使用默认中国人作息配置(每轮1小时)" } def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig: """解析时间配置结果,并验证agents_per_hour值不超过总agent数""" # 获取原始值 agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15)) agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5)) # 验证并修正:确保不超过总agent数 if agents_per_hour_min > num_entities: logger.warning(f"agents_per_hour_min ({agents_per_hour_min}) 超过总Agent数 ({num_entities}),已修正") agents_per_hour_min = max(1, num_entities // 10) if agents_per_hour_max > num_entities: logger.warning(f"agents_per_hour_max ({agents_per_hour_max}) 超过总Agent数 ({num_entities}),已修正") agents_per_hour_max = 
max(agents_per_hour_min + 1, num_entities // 2) # 确保 min < max if agents_per_hour_min >= agents_per_hour_max: agents_per_hour_min = max(1, agents_per_hour_max // 2) logger.warning(f"agents_per_hour_min >= max,已修正为 {agents_per_hour_min}") return TimeSimulationConfig( total_simulation_hours=result.get("total_simulation_hours", 72), minutes_per_round=result.get("minutes_per_round", 60), # 默认每轮1小时 agents_per_hour_min=agents_per_hour_min, agents_per_hour_max=agents_per_hour_max, peak_hours=result.get("peak_hours", [19, 20, 21, 22]), off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]), off_peak_activity_multiplier=0.05, # 凌晨几乎无人 morning_hours=result.get("morning_hours", [6, 7, 8]), morning_activity_multiplier=0.4, work_hours=result.get("work_hours", list(range(9, 19))), work_activity_multiplier=0.7, peak_activity_multiplier=1.5 ) def _generate_event_config( self, context: str, simulation_requirement: str, entities: List[EntityNode] ) -> Dict[str, Any]: """生成事件配置""" # 获取可用的实体类型列表,供 LLM 参考 entity_types_available = list(set( e.get_entity_type() or "Unknown" for e in entities )) # 为每种类型列出代表性实体名称 type_examples = {} for e in entities: etype = e.get_entity_type() or "Unknown" if etype not in type_examples: type_examples[etype] = [] if len(type_examples[etype]) < 3: type_examples[etype].append(e.name) type_info = "\n".join([ f"- {t}: {', '.join(examples)}" for t, examples in type_examples.items() ]) # 使用配置的上下文截断长度 context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH] prompt = f"""基于以下模拟需求,生成事件配置。 模拟需求: {simulation_requirement} {context_truncated} ## 可用实体类型及示例 {type_info} ## 任务 请生成事件配置JSON: - 提取热点话题关键词 - 描述舆论发展方向 - 设计初始帖子内容,**每个帖子必须指定 poster_type(发布者类型)** **重要**: poster_type 必须从上面的"可用实体类型"中选择,这样初始帖子才能分配给合适的 Agent 发布。 例如:官方声明应由 Official/University 类型发布,新闻由 MediaOutlet 发布,学生观点由 Student 发布。 返回JSON格式(不要markdown): {{ "hot_topics": ["关键词1", "关键词2", ...], "narrative_direction": "<舆论发展方向描述>", "initial_posts": [ {{"content": "帖子内容", "poster_type": 
"实体类型(必须从可用类型中选择)"}}, ... ], "reasoning": "<简要说明>" }}""" system_prompt = "你是舆论分析专家。返回纯JSON格式。注意 poster_type 必须精确匹配可用实体类型。" try: return self._call_llm_with_retry(prompt, system_prompt) except Exception as e: logger.warning(f"事件配置LLM生成失败: {e}, 使用默认配置") return { "hot_topics": [], "narrative_direction": "", "initial_posts": [], "reasoning": "使用默认配置" } def _parse_event_config(self, result: Dict[str, Any]) -> EventConfig: """解析事件配置结果""" return EventConfig( initial_posts=result.get("initial_posts", []), scheduled_events=[], hot_topics=result.get("hot_topics", []), narrative_direction=result.get("narrative_direction", "") ) def _assign_initial_post_agents( self, event_config: EventConfig, agent_configs: List[AgentActivityConfig] ) -> EventConfig: """ 为初始帖子分配合适的发布者 Agent 根据每个帖子的 poster_type 匹配最合适的 agent_id """ if not event_config.initial_posts: return event_config # 按实体类型建立 agent 索引 agents_by_type: Dict[str, List[AgentActivityConfig]] = {} for agent in agent_configs: etype = agent.entity_type.lower() if etype not in agents_by_type: agents_by_type[etype] = [] agents_by_type[etype].append(agent) # 类型映射表(处理 LLM 可能输出的不同格式) type_aliases = { "official": ["official", "university", "governmentagency", "government"], "university": ["university", "official"], "mediaoutlet": ["mediaoutlet", "media"], "student": ["student", "person"], "professor": ["professor", "expert", "teacher"], "alumni": ["alumni", "person"], "organization": ["organization", "ngo", "company", "group"], "person": ["person", "student", "alumni"], } # 记录每种类型已使用的 agent 索引,避免重复使用同一个 agent used_indices: Dict[str, int] = {} updated_posts = [] for post in event_config.initial_posts: poster_type = post.get("poster_type", "").lower() content = post.get("content", "") # 尝试找到匹配的 agent matched_agent_id = None # 1. 直接匹配 if poster_type in agents_by_type: agents = agents_by_type[poster_type] idx = used_indices.get(poster_type, 0) % len(agents) matched_agent_id = agents[idx].agent_id used_indices[poster_type] = idx + 1 else: # 2. 
使用别名匹配 for alias_key, aliases in type_aliases.items(): if poster_type in aliases or alias_key == poster_type: for alias in aliases: if alias in agents_by_type: agents = agents_by_type[alias] idx = used_indices.get(alias, 0) % len(agents) matched_agent_id = agents[idx].agent_id used_indices[alias] = idx + 1 break if matched_agent_id is not None: break # 3. 如果仍未找到,使用影响力最高的 agent if matched_agent_id is None: logger.warning(f"未找到类型 '{poster_type}' 的匹配 Agent,使用影响力最高的 Agent") if agent_configs: # 按影响力排序,选择影响力最高的 sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True) matched_agent_id = sorted_agents[0].agent_id else: matched_agent_id = 0 updated_posts.append({ "content": content, "poster_type": post.get("poster_type", "Unknown"), "poster_agent_id": matched_agent_id }) logger.info(f"初始帖子分配: poster_type='{poster_type}' -> agent_id={matched_agent_id}") event_config.initial_posts = updated_posts return event_config def _generate_agent_configs_batch( self, context: str, entities: List[EntityNode], start_idx: int, simulation_requirement: str ) -> List[AgentActivityConfig]: """分批生成Agent配置""" # 构建实体信息(使用配置的摘要长度) entity_list = [] summary_len = self.AGENT_SUMMARY_LENGTH for i, e in enumerate(entities): entity_list.append({ "agent_id": start_idx + i, "entity_name": e.name, "entity_type": e.get_entity_type() or "Unknown", "summary": e.summary[:summary_len] if e.summary else "" }) prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。 模拟需求: {simulation_requirement} ## 实体列表 ```json {json.dumps(entity_list, ensure_ascii=False, indent=2)} ``` ## 任务 为每个实体生成活动配置,注意: - **时间符合中国人作息**:凌晨0-5点几乎不活动,晚间19-22点最活跃 - **官方机构**(University/GovernmentAgency):活跃度低(0.1-0.3),工作时间(9-17)活动,响应慢(60-240分钟),影响力高(2.5-3.0) - **媒体**(MediaOutlet):活跃度中(0.4-0.6),全天活动(8-23),响应快(5-30分钟),影响力高(2.0-2.5) - **个人**(Student/Person/Alumni):活跃度高(0.6-0.9),主要晚间活动(18-23),响应快(1-15分钟),影响力低(0.8-1.2) - **公众人物/专家**:活跃度中(0.4-0.6),影响力中高(1.5-2.0) 返回JSON格式(不要markdown): {{ "agent_configs": [ {{ "agent_id": <必须与输入一致>, 
"activity_level": <0.0-1.0>, "posts_per_hour": <发帖频率>, "comments_per_hour": <评论频率>, "active_hours": [<活跃小时列表,考虑中国人作息>], "response_delay_min": <最小响应延迟分钟>, "response_delay_max": <最大响应延迟分钟>, "sentiment_bias": <-1.0到1.0>, "stance": "", "influence_weight": <影响力权重> }}, ... ] }}""" system_prompt = "你是社交媒体行为分析专家。返回纯JSON,配置需符合中国人作息习惯。" try: result = self._call_llm_with_retry(prompt, system_prompt) llm_configs = {cfg["agent_id"]: cfg for cfg in result.get("agent_configs", [])} except Exception as e: logger.warning(f"Agent配置批次LLM生成失败: {e}, 使用规则生成") llm_configs = {} # 构建AgentActivityConfig对象 configs = [] for i, entity in enumerate(entities): agent_id = start_idx + i cfg = llm_configs.get(agent_id, {}) # 如果LLM没有生成,使用规则生成 if not cfg: cfg = self._generate_agent_config_by_rule(entity) config = AgentActivityConfig( agent_id=agent_id, entity_uuid=entity.uuid, entity_name=entity.name, entity_type=entity.get_entity_type() or "Unknown", activity_level=cfg.get("activity_level", 0.5), posts_per_hour=cfg.get("posts_per_hour", 0.5), comments_per_hour=cfg.get("comments_per_hour", 1.0), active_hours=cfg.get("active_hours", list(range(9, 23))), response_delay_min=cfg.get("response_delay_min", 5), response_delay_max=cfg.get("response_delay_max", 60), sentiment_bias=cfg.get("sentiment_bias", 0.0), stance=cfg.get("stance", "neutral"), influence_weight=cfg.get("influence_weight", 1.0) ) configs.append(config) return configs def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: """基于规则生成单个Agent配置(中国人作息)""" entity_type = (entity.get_entity_type() or "Unknown").lower() if entity_type in ["university", "governmentagency", "ngo"]: # 官方机构:工作时间活动,低频率,高影响力 return { "activity_level": 0.2, "posts_per_hour": 0.1, "comments_per_hour": 0.05, "active_hours": list(range(9, 18)), # 9:00-17:59 "response_delay_min": 60, "response_delay_max": 240, "sentiment_bias": 0.0, "stance": "neutral", "influence_weight": 3.0 } elif entity_type in ["mediaoutlet"]: # 媒体:全天活动,中等频率,高影响力 return { 
"activity_level": 0.5, "posts_per_hour": 0.8, "comments_per_hour": 0.3, "active_hours": list(range(7, 24)), # 7:00-23:59 "response_delay_min": 5, "response_delay_max": 30, "sentiment_bias": 0.0, "stance": "observer", "influence_weight": 2.5 } elif entity_type in ["professor", "expert", "official"]: # 专家/教授:工作+晚间活动,中等频率 return { "activity_level": 0.4, "posts_per_hour": 0.3, "comments_per_hour": 0.5, "active_hours": list(range(8, 22)), # 8:00-21:59 "response_delay_min": 15, "response_delay_max": 90, "sentiment_bias": 0.0, "stance": "neutral", "influence_weight": 2.0 } elif entity_type in ["student"]: # 学生:晚间为主,高频率 return { "activity_level": 0.8, "posts_per_hour": 0.6, "comments_per_hour": 1.5, "active_hours": [8, 9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 上午+晚间 "response_delay_min": 1, "response_delay_max": 15, "sentiment_bias": 0.0, "stance": "neutral", "influence_weight": 0.8 } elif entity_type in ["alumni"]: # 校友:晚间为主 return { "activity_level": 0.6, "posts_per_hour": 0.4, "comments_per_hour": 0.8, "active_hours": [12, 13, 19, 20, 21, 22, 23], # 午休+晚间 "response_delay_min": 5, "response_delay_max": 30, "sentiment_bias": 0.0, "stance": "neutral", "influence_weight": 1.0 } else: # 普通人:晚间高峰 return { "activity_level": 0.7, "posts_per_hour": 0.5, "comments_per_hour": 1.2, "active_hours": [9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 白天+晚间 "response_delay_min": 2, "response_delay_max": 20, "sentiment_bias": 0.0, "stance": "neutral", "influence_weight": 1.0 } ================================================ FILE: backend/app/services/simulation_ipc.py ================================================ """ 模拟IPC通信模块 用于Flask后端和模拟脚本之间的进程间通信 通过文件系统实现简单的命令/响应模式: 1. Flask写入命令到 commands/ 目录 2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录 3. 
@dataclass
class IPCCommand:
    """A command exchanged between the two sides of the file-based IPC channel."""

    command_id: str
    command_type: CommandType
    args: Dict[str, Any]
    # ISO-8601 timestamp, taken at construction time unless supplied.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict; the command type is stored by value."""
        return dict(
            command_id=self.command_id,
            command_type=self.command_type.value,
            args=self.args,
            timestamp=self.timestamp,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand':
        """Rebuild a command from a dict produced by ``to_dict``.

        ``command_id`` and ``command_type`` are required; ``args`` defaults
        to an empty dict and ``timestamp`` to the current time.
        """
        command_id = data["command_id"]
        command_type = CommandType(data["command_type"])
        args = data.get("args", {})
        if "timestamp" in data:
            timestamp = data["timestamp"]
        else:
            timestamp = datetime.now().isoformat()
        return cls(
            command_id=command_id,
            command_type=command_type,
            args=args,
            timestamp=timestamp,
        )
def send_command(
    self,
    command_type: CommandType,
    args: Dict[str, Any],
    timeout: float = 60.0,
    poll_interval: float = 0.5
) -> IPCResponse:
    """Send a command and block until its response arrives.

    The command is written to ``<commands_dir>/<command_id>.json``; the
    method then polls for ``<responses_dir>/<command_id>.json``.

    The command file and the response file are removed independently of
    each other. The responder normally deletes the command file itself
    once it has answered, so a missing command file must not prevent the
    response file from being cleaned up.

    The timeout is measured with a monotonic clock, so system clock
    adjustments cannot shorten or extend the wait.

    Args:
        command_type: Command type.
        args: Command arguments.
        timeout: Maximum time to wait for the response, in seconds.
        poll_interval: Delay between two polls, in seconds.

    Returns:
        The parsed IPCResponse.

    Raises:
        TimeoutError: No parsable response arrived within ``timeout``.
    """
    def remove_quietly(path: str) -> None:
        # Best-effort cleanup: the file may already be gone.
        try:
            os.remove(path)
        except OSError:
            pass

    command_id = str(uuid.uuid4())
    command = IPCCommand(
        command_id=command_id,
        command_type=command_type,
        args=args
    )

    # Write the command file.
    command_file = os.path.join(self.commands_dir, f"{command_id}.json")
    with open(command_file, 'w', encoding='utf-8') as f:
        json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)

    logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}")

    # Wait for the response.
    response_file = os.path.join(self.responses_dir, f"{command_id}.json")
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        if os.path.exists(response_file):
            try:
                with open(response_file, 'r', encoding='utf-8') as f:
                    response_data = json.load(f)
                response = IPCResponse.from_dict(response_data)
            except (json.JSONDecodeError, KeyError) as e:
                # Possibly a half-written file; retry on the next poll.
                logger.warning(f"解析响应失败: {e}")
            else:
                remove_quietly(command_file)
                remove_quietly(response_file)
                logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}")
                return response

        time.sleep(poll_interval)

    # Timed out: withdraw the command so it is not executed later.
    logger.error(f"等待IPC响应超时: command_id={command_id}")
    remove_quietly(command_file)

    raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
def check_env_alive(self) -> bool:
    """Report whether the simulation environment is marked as alive.

    Reads ``env_status.json`` in the simulation directory and returns True
    only when its ``status`` entry equals ``"alive"``. A missing file, an
    unreadable file or invalid JSON all count as not alive.
    """
    status_path = os.path.join(self.simulation_dir, "env_status.json")

    if os.path.exists(status_path):
        try:
            with open(status_path, 'r', encoding='utf-8') as status_handle:
                payload = json.load(status_handle)
        except (json.JSONDecodeError, OSError):
            return False
        else:
            return payload.get("status") == "alive"

    return False
def poll_commands(self) -> Optional[IPCCommand]:
    """Return the oldest readable pending command, or None.

    Lists ``*.json`` files in the commands directory, orders them by
    modification time (oldest first) and returns the first one that parses
    into an IPCCommand.

    The sender deletes its command file when it times out, so a file can
    vanish between listing and reading. Such files, and files that are
    malformed or carry an unknown command type, are skipped instead of
    aborting the poll.

    Returns:
        IPCCommand, or None when no command is available.
    """
    try:
        filenames = os.listdir(self.commands_dir)
    except OSError:
        # Directory missing or not readable: nothing to process.
        return None

    # Collect (mtime, path) pairs; ties are ordered by path.
    candidates = []
    for filename in filenames:
        if not filename.endswith('.json'):
            continue
        filepath = os.path.join(self.commands_dir, filename)
        try:
            candidates.append((os.path.getmtime(filepath), filepath))
        except OSError:
            # Removed by the sender after listdir(); ignore it.
            continue
    candidates.sort()

    for _, filepath in candidates:
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return IPCCommand.from_dict(data)
        except (ValueError, KeyError, TypeError, OSError) as e:
            # ValueError covers json.JSONDecodeError and unknown enum values.
            logger.warning(f"读取命令文件失败: {filepath}, {e}")
            continue

    return None
import shutil from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime from enum import Enum from ..config import Config from ..utils.logger import get_logger from .zep_entity_reader import ZepEntityReader, FilteredEntities from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters logger = get_logger('mirofish.simulation') class SimulationStatus(str, Enum): """模拟状态""" CREATED = "created" PREPARING = "preparing" READY = "ready" RUNNING = "running" PAUSED = "paused" STOPPED = "stopped" # 模拟被手动停止 COMPLETED = "completed" # 模拟自然完成 FAILED = "failed" class PlatformType(str, Enum): """平台类型""" TWITTER = "twitter" REDDIT = "reddit" @dataclass class SimulationState: """模拟状态""" simulation_id: str project_id: str graph_id: str # 平台启用状态 enable_twitter: bool = True enable_reddit: bool = True # 状态 status: SimulationStatus = SimulationStatus.CREATED # 准备阶段数据 entities_count: int = 0 profiles_count: int = 0 entity_types: List[str] = field(default_factory=list) # 配置生成信息 config_generated: bool = False config_reasoning: str = "" # 运行时数据 current_round: int = 0 twitter_status: str = "not_started" reddit_status: str = "not_started" # 时间戳 created_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) # 错误信息 error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """完整状态字典(内部使用)""" return { "simulation_id": self.simulation_id, "project_id": self.project_id, "graph_id": self.graph_id, "enable_twitter": self.enable_twitter, "enable_reddit": self.enable_reddit, "status": self.status.value, "entities_count": self.entities_count, "profiles_count": self.profiles_count, "entity_types": self.entity_types, "config_generated": self.config_generated, "config_reasoning": self.config_reasoning, "current_round": self.current_round, 
"twitter_status": self.twitter_status, "reddit_status": self.reddit_status, "created_at": self.created_at, "updated_at": self.updated_at, "error": self.error, } def to_simple_dict(self) -> Dict[str, Any]: """简化状态字典(API返回使用)""" return { "simulation_id": self.simulation_id, "project_id": self.project_id, "graph_id": self.graph_id, "status": self.status.value, "entities_count": self.entities_count, "profiles_count": self.profiles_count, "entity_types": self.entity_types, "config_generated": self.config_generated, "error": self.error, } class SimulationManager: """ 模拟管理器 核心功能: 1. 从Zep图谱读取实体并过滤 2. 生成OASIS Agent Profile 3. 使用LLM智能生成模拟配置参数 4. 准备预设脚本所需的所有文件 """ # 模拟数据存储目录 SIMULATION_DATA_DIR = os.path.join( os.path.dirname(__file__), '../../uploads/simulations' ) def __init__(self): # 确保目录存在 os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) # 内存中的模拟状态缓存 self._simulations: Dict[str, SimulationState] = {} def _get_simulation_dir(self, simulation_id: str) -> str: """获取模拟数据目录""" sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id) os.makedirs(sim_dir, exist_ok=True) return sim_dir def _save_simulation_state(self, state: SimulationState): """保存模拟状态到文件""" sim_dir = self._get_simulation_dir(state.simulation_id) state_file = os.path.join(sim_dir, "state.json") state.updated_at = datetime.now().isoformat() with open(state_file, 'w', encoding='utf-8') as f: json.dump(state.to_dict(), f, ensure_ascii=False, indent=2) self._simulations[state.simulation_id] = state def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]: """从文件加载模拟状态""" if simulation_id in self._simulations: return self._simulations[simulation_id] sim_dir = self._get_simulation_dir(simulation_id) state_file = os.path.join(sim_dir, "state.json") if not os.path.exists(state_file): return None with open(state_file, 'r', encoding='utf-8') as f: data = json.load(f) state = SimulationState( simulation_id=simulation_id, project_id=data.get("project_id", ""), 
graph_id=data.get("graph_id", ""), enable_twitter=data.get("enable_twitter", True), enable_reddit=data.get("enable_reddit", True), status=SimulationStatus(data.get("status", "created")), entities_count=data.get("entities_count", 0), profiles_count=data.get("profiles_count", 0), entity_types=data.get("entity_types", []), config_generated=data.get("config_generated", False), config_reasoning=data.get("config_reasoning", ""), current_round=data.get("current_round", 0), twitter_status=data.get("twitter_status", "not_started"), reddit_status=data.get("reddit_status", "not_started"), created_at=data.get("created_at", datetime.now().isoformat()), updated_at=data.get("updated_at", datetime.now().isoformat()), error=data.get("error"), ) self._simulations[simulation_id] = state return state def create_simulation( self, project_id: str, graph_id: str, enable_twitter: bool = True, enable_reddit: bool = True, ) -> SimulationState: """ 创建新的模拟 Args: project_id: 项目ID graph_id: Zep图谱ID enable_twitter: 是否启用Twitter模拟 enable_reddit: 是否启用Reddit模拟 Returns: SimulationState """ import uuid simulation_id = f"sim_{uuid.uuid4().hex[:12]}" state = SimulationState( simulation_id=simulation_id, project_id=project_id, graph_id=graph_id, enable_twitter=enable_twitter, enable_reddit=enable_reddit, status=SimulationStatus.CREATED, ) self._save_simulation_state(state) logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}") return state def prepare_simulation( self, simulation_id: str, simulation_requirement: str, document_text: str, defined_entity_types: Optional[List[str]] = None, use_llm_for_profiles: bool = True, progress_callback: Optional[callable] = None, parallel_profile_count: int = 3 ) -> SimulationState: """ 准备模拟环境(全程自动化) 步骤: 1. 从Zep图谱读取并过滤实体 2. 为每个实体生成OASIS Agent Profile(可选LLM增强,支持并行) 3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等) 4. 保存配置文件和Profile文件 5. 
复制预设脚本到模拟目录 Args: simulation_id: 模拟ID simulation_requirement: 模拟需求描述(用于LLM生成配置) document_text: 原始文档内容(用于LLM理解背景) defined_entity_types: 预定义的实体类型(可选) use_llm_for_profiles: 是否使用LLM生成详细人设 progress_callback: 进度回调函数 (stage, progress, message) parallel_profile_count: 并行生成人设的数量,默认3 Returns: SimulationState """ state = self._load_simulation_state(simulation_id) if not state: raise ValueError(f"模拟不存在: {simulation_id}") try: state.status = SimulationStatus.PREPARING self._save_simulation_state(state) sim_dir = self._get_simulation_dir(simulation_id) # ========== 阶段1: 读取并过滤实体 ========== if progress_callback: progress_callback("reading", 0, "正在连接Zep图谱...") reader = ZepEntityReader() if progress_callback: progress_callback("reading", 30, "正在读取节点数据...") filtered = reader.filter_defined_entities( graph_id=state.graph_id, defined_entity_types=defined_entity_types, enrich_with_edges=True ) state.entities_count = filtered.filtered_count state.entity_types = list(filtered.entity_types) if progress_callback: progress_callback( "reading", 100, f"完成,共 {filtered.filtered_count} 个实体", current=filtered.filtered_count, total=filtered.filtered_count ) if filtered.filtered_count == 0: state.status = SimulationStatus.FAILED state.error = "没有找到符合条件的实体,请检查图谱是否正确构建" self._save_simulation_state(state) return state # ========== 阶段2: 生成Agent Profile ========== total_entities = len(filtered.entities) if progress_callback: progress_callback( "generating_profiles", 0, "开始生成...", current=0, total=total_entities ) # 传入graph_id以启用Zep检索功能,获取更丰富的上下文 generator = OasisProfileGenerator(graph_id=state.graph_id) def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_profiles", int(current / total * 100), msg, current=current, total=total, item_name=msg ) # 设置实时保存的文件路径(优先使用 Reddit JSON 格式) realtime_output_path = None realtime_platform = "reddit" if state.enable_reddit: realtime_output_path = os.path.join(sim_dir, "reddit_profiles.json") realtime_platform = "reddit" elif 
state.enable_twitter: realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv") realtime_platform = "twitter" profiles = generator.generate_profiles_from_entities( entities=filtered.entities, use_llm=use_llm_for_profiles, progress_callback=profile_progress, graph_id=state.graph_id, # 传入graph_id用于Zep检索 parallel_count=parallel_profile_count, # 并行生成数量 realtime_output_path=realtime_output_path, # 实时保存路径 output_platform=realtime_platform # 输出格式 ) state.profiles_count = len(profiles) # 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式) # Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性 if progress_callback: progress_callback( "generating_profiles", 95, "保存Profile文件...", current=total_entities, total=total_entities ) if state.enable_reddit: generator.save_profiles( profiles=profiles, file_path=os.path.join(sim_dir, "reddit_profiles.json"), platform="reddit" ) if state.enable_twitter: # Twitter使用CSV格式!这是OASIS的要求 generator.save_profiles( profiles=profiles, file_path=os.path.join(sim_dir, "twitter_profiles.csv"), platform="twitter" ) if progress_callback: progress_callback( "generating_profiles", 100, f"完成,共 {len(profiles)} 个Profile", current=len(profiles), total=len(profiles) ) # ========== 阶段3: LLM智能生成模拟配置 ========== if progress_callback: progress_callback( "generating_config", 0, "正在分析模拟需求...", current=0, total=3 ) config_generator = SimulationConfigGenerator() if progress_callback: progress_callback( "generating_config", 30, "正在调用LLM生成配置...", current=1, total=3 ) sim_params = config_generator.generate_config( simulation_id=simulation_id, project_id=state.project_id, graph_id=state.graph_id, simulation_requirement=simulation_requirement, document_text=document_text, entities=filtered.entities, enable_twitter=state.enable_twitter, enable_reddit=state.enable_reddit ) if progress_callback: progress_callback( "generating_config", 70, "正在保存配置文件...", current=2, total=3 ) # 保存配置文件 config_path = os.path.join(sim_dir, "simulation_config.json") with open(config_path, 'w', encoding='utf-8') 
as f: f.write(sim_params.to_json()) state.config_generated = True state.config_reasoning = sim_params.generation_reasoning if progress_callback: progress_callback( "generating_config", 100, "配置生成完成", current=3, total=3 ) # 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录 # 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本 # 更新状态 state.status = SimulationStatus.READY self._save_simulation_state(state) logger.info(f"模拟准备完成: {simulation_id}, " f"entities={state.entities_count}, profiles={state.profiles_count}") return state except Exception as e: logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}") import traceback logger.error(traceback.format_exc()) state.status = SimulationStatus.FAILED state.error = str(e) self._save_simulation_state(state) raise def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: """获取模拟状态""" return self._load_simulation_state(simulation_id) def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]: """列出所有模拟""" simulations = [] if os.path.exists(self.SIMULATION_DATA_DIR): for sim_id in os.listdir(self.SIMULATION_DATA_DIR): # 跳过隐藏文件(如 .DS_Store)和非目录文件 sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id) if sim_id.startswith('.') or not os.path.isdir(sim_path): continue state = self._load_simulation_state(sim_id) if state: if project_id is None or state.project_id == project_id: simulations.append(state) return simulations def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]: """获取模拟的Agent Profile""" state = self._load_simulation_state(simulation_id) if not state: raise ValueError(f"模拟不存在: {simulation_id}") sim_dir = self._get_simulation_dir(simulation_id) profile_path = os.path.join(sim_dir, f"{platform}_profiles.json") if not os.path.exists(profile_path): return [] with open(profile_path, 'r', encoding='utf-8') as f: return json.load(f) def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: """获取模拟配置""" sim_dir = 
self._get_simulation_dir(simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): return None with open(config_path, 'r', encoding='utf-8') as f: return json.load(f) def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: """获取运行说明""" sim_dir = self._get_simulation_dir(simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) return { "simulation_dir": sim_dir, "scripts_dir": scripts_dir, "config_file": config_path, "commands": { "twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}", "reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}", "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}", }, "instructions": ( f"1. 激活conda环境: conda activate MiroFish\n" f"2. 运行模拟 (脚本位于 {scripts_dir}):\n" f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n" f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n" f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" ) } ================================================ FILE: backend/app/services/simulation_runner.py ================================================ """ OASIS模拟运行器 在后台运行模拟并记录每个Agent的动作,支持实时状态监控 """ import os import sys import json import time import asyncio import threading import subprocess import signal import atexit from typing import Dict, Any, List, Optional, Union from dataclasses import dataclass, field from datetime import datetime from enum import Enum from queue import Queue from ..config import Config from ..utils.logger import get_logger from .zep_graph_memory_updater import ZepGraphMemoryManager from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse logger = get_logger('mirofish.simulation_runner') # 标记是否已注册清理函数 
_cleanup_registered = False # 平台检测 IS_WINDOWS = sys.platform == 'win32' class RunnerStatus(str, Enum): """运行器状态""" IDLE = "idle" STARTING = "starting" RUNNING = "running" PAUSED = "paused" STOPPING = "stopping" STOPPED = "stopped" COMPLETED = "completed" FAILED = "failed" @dataclass class AgentAction: """Agent动作记录""" round_num: int timestamp: str platform: str # twitter / reddit agent_id: int agent_name: str action_type: str # CREATE_POST, LIKE_POST, etc. action_args: Dict[str, Any] = field(default_factory=dict) result: Optional[str] = None success: bool = True def to_dict(self) -> Dict[str, Any]: return { "round_num": self.round_num, "timestamp": self.timestamp, "platform": self.platform, "agent_id": self.agent_id, "agent_name": self.agent_name, "action_type": self.action_type, "action_args": self.action_args, "result": self.result, "success": self.success, } @dataclass class RoundSummary: """每轮摘要""" round_num: int start_time: str end_time: Optional[str] = None simulated_hour: int = 0 twitter_actions: int = 0 reddit_actions: int = 0 active_agents: List[int] = field(default_factory=list) actions: List[AgentAction] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { "round_num": self.round_num, "start_time": self.start_time, "end_time": self.end_time, "simulated_hour": self.simulated_hour, "twitter_actions": self.twitter_actions, "reddit_actions": self.reddit_actions, "active_agents": self.active_agents, "actions_count": len(self.actions), "actions": [a.to_dict() for a in self.actions], } @dataclass class SimulationRunState: """模拟运行状态(实时)""" simulation_id: str runner_status: RunnerStatus = RunnerStatus.IDLE # 进度信息 current_round: int = 0 total_rounds: int = 0 simulated_hours: int = 0 total_simulation_hours: int = 0 # 各平台独立轮次和模拟时间(用于双平台并行显示) twitter_current_round: int = 0 reddit_current_round: int = 0 twitter_simulated_hours: int = 0 reddit_simulated_hours: int = 0 # 平台状态 twitter_running: bool = False reddit_running: bool = False 
twitter_actions_count: int = 0 reddit_actions_count: int = 0 # 平台完成状态(通过检测 actions.jsonl 中的 simulation_end 事件) twitter_completed: bool = False reddit_completed: bool = False # 每轮摘要 rounds: List[RoundSummary] = field(default_factory=list) # 最近动作(用于前端实时展示) recent_actions: List[AgentAction] = field(default_factory=list) max_recent_actions: int = 50 # 时间戳 started_at: Optional[str] = None updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) completed_at: Optional[str] = None # 错误信息 error: Optional[str] = None # 进程ID(用于停止) process_pid: Optional[int] = None def add_action(self, action: AgentAction): """添加动作到最近动作列表""" self.recent_actions.insert(0, action) if len(self.recent_actions) > self.max_recent_actions: self.recent_actions = self.recent_actions[:self.max_recent_actions] if action.platform == "twitter": self.twitter_actions_count += 1 else: self.reddit_actions_count += 1 self.updated_at = datetime.now().isoformat() def to_dict(self) -> Dict[str, Any]: return { "simulation_id": self.simulation_id, "runner_status": self.runner_status.value, "current_round": self.current_round, "total_rounds": self.total_rounds, "simulated_hours": self.simulated_hours, "total_simulation_hours": self.total_simulation_hours, "progress_percent": round(self.current_round / max(self.total_rounds, 1) * 100, 1), # 各平台独立轮次和时间 "twitter_current_round": self.twitter_current_round, "reddit_current_round": self.reddit_current_round, "twitter_simulated_hours": self.twitter_simulated_hours, "reddit_simulated_hours": self.reddit_simulated_hours, "twitter_running": self.twitter_running, "reddit_running": self.reddit_running, "twitter_completed": self.twitter_completed, "reddit_completed": self.reddit_completed, "twitter_actions_count": self.twitter_actions_count, "reddit_actions_count": self.reddit_actions_count, "total_actions_count": self.twitter_actions_count + self.reddit_actions_count, "started_at": self.started_at, "updated_at": self.updated_at, "completed_at": 
self.completed_at, "error": self.error, "process_pid": self.process_pid, } def to_detail_dict(self) -> Dict[str, Any]: """包含最近动作的详细信息""" result = self.to_dict() result["recent_actions"] = [a.to_dict() for a in self.recent_actions] result["rounds_count"] = len(self.rounds) return result class SimulationRunner: """ 模拟运行器 负责: 1. 在后台进程中运行OASIS模拟 2. 解析运行日志,记录每个Agent的动作 3. 提供实时状态查询接口 4. 支持暂停/停止/恢复操作 """ # 运行状态存储目录 RUN_STATE_DIR = os.path.join( os.path.dirname(__file__), '../../uploads/simulations' ) # 脚本目录 SCRIPTS_DIR = os.path.join( os.path.dirname(__file__), '../../scripts' ) # 内存中的运行状态 _run_states: Dict[str, SimulationRunState] = {} _processes: Dict[str, subprocess.Popen] = {} _action_queues: Dict[str, Queue] = {} _monitor_threads: Dict[str, threading.Thread] = {} _stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄 _stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄 # 图谱记忆更新配置 _graph_memory_enabled: Dict[str, bool] = {} # simulation_id -> enabled @classmethod def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: """获取运行状态""" if simulation_id in cls._run_states: return cls._run_states[simulation_id] # 尝试从文件加载 state = cls._load_run_state(simulation_id) if state: cls._run_states[simulation_id] = state return state @classmethod def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: """从文件加载运行状态""" state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json") if not os.path.exists(state_file): return None try: with open(state_file, 'r', encoding='utf-8') as f: data = json.load(f) state = SimulationRunState( simulation_id=simulation_id, runner_status=RunnerStatus(data.get("runner_status", "idle")), current_round=data.get("current_round", 0), total_rounds=data.get("total_rounds", 0), simulated_hours=data.get("simulated_hours", 0), total_simulation_hours=data.get("total_simulation_hours", 0), # 各平台独立轮次和时间 twitter_current_round=data.get("twitter_current_round", 0), 
reddit_current_round=data.get("reddit_current_round", 0), twitter_simulated_hours=data.get("twitter_simulated_hours", 0), reddit_simulated_hours=data.get("reddit_simulated_hours", 0), twitter_running=data.get("twitter_running", False), reddit_running=data.get("reddit_running", False), twitter_completed=data.get("twitter_completed", False), reddit_completed=data.get("reddit_completed", False), twitter_actions_count=data.get("twitter_actions_count", 0), reddit_actions_count=data.get("reddit_actions_count", 0), started_at=data.get("started_at"), updated_at=data.get("updated_at", datetime.now().isoformat()), completed_at=data.get("completed_at"), error=data.get("error"), process_pid=data.get("process_pid"), ) # 加载最近动作 actions_data = data.get("recent_actions", []) for a in actions_data: state.recent_actions.append(AgentAction( round_num=a.get("round_num", 0), timestamp=a.get("timestamp", ""), platform=a.get("platform", ""), agent_id=a.get("agent_id", 0), agent_name=a.get("agent_name", ""), action_type=a.get("action_type", ""), action_args=a.get("action_args", {}), result=a.get("result"), success=a.get("success", True), )) return state except Exception as e: logger.error(f"加载运行状态失败: {str(e)}") return None @classmethod def _save_run_state(cls, state: SimulationRunState): """保存运行状态到文件""" sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id) os.makedirs(sim_dir, exist_ok=True) state_file = os.path.join(sim_dir, "run_state.json") data = state.to_detail_dict() with open(state_file, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) cls._run_states[state.simulation_id] = state @classmethod def start_simulation( cls, simulation_id: str, platform: str = "parallel", # twitter / reddit / parallel max_rounds: int = None, # 最大模拟轮数(可选,用于截断过长的模拟) enable_graph_memory_update: bool = False, # 是否将活动更新到Zep图谱 graph_id: str = None # Zep图谱ID(启用图谱更新时必需) ) -> SimulationRunState: """ 启动模拟 Args: simulation_id: 模拟ID platform: 运行平台 (twitter/reddit/parallel) 
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) enable_graph_memory_update: 是否将Agent活动动态更新到Zep图谱 graph_id: Zep图谱ID(启用图谱更新时必需) Returns: SimulationRunState """ # 检查是否已在运行 existing = cls.get_run_state(simulation_id) if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]: raise ValueError(f"模拟已在运行中: {simulation_id}") # 加载模拟配置 sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口") with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) # 初始化运行状态 time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = int(total_hours * 60 / minutes_per_round) # 如果指定了最大轮数,则截断 if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") state = SimulationRunState( simulation_id=simulation_id, runner_status=RunnerStatus.STARTING, total_rounds=total_rounds, total_simulation_hours=total_hours, started_at=datetime.now().isoformat(), ) cls._save_run_state(state) # 如果启用图谱记忆更新,创建更新器 if enable_graph_memory_update: if not graph_id: raise ValueError("启用图谱记忆更新时必须提供 graph_id") try: ZepGraphMemoryManager.create_updater(simulation_id, graph_id) cls._graph_memory_enabled[simulation_id] = True logger.info(f"已启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}") except Exception as e: logger.error(f"创建图谱记忆更新器失败: {e}") cls._graph_memory_enabled[simulation_id] = False else: cls._graph_memory_enabled[simulation_id] = False # 确定运行哪个脚本(脚本位于 backend/scripts/ 目录) if platform == "twitter": script_name = "run_twitter_simulation.py" state.twitter_running = True elif platform == "reddit": script_name = "run_reddit_simulation.py" 
state.reddit_running = True else: script_name = "run_parallel_simulation.py" state.twitter_running = True state.reddit_running = True script_path = os.path.join(cls.SCRIPTS_DIR, script_name) if not os.path.exists(script_path): raise ValueError(f"脚本不存在: {script_path}") # 创建动作队列 action_queue = Queue() cls._action_queues[simulation_id] = action_queue # 启动模拟进程 try: # 构建运行命令,使用完整路径 # 新的日志结构: # twitter/actions.jsonl - Twitter 动作日志 # reddit/actions.jsonl - Reddit 动作日志 # simulation.log - 主进程日志 cmd = [ sys.executable, # Python解释器 script_path, "--config", config_path, # 使用完整配置文件路径 ] # 如果指定了最大轮数,添加到命令行参数 if max_rounds is not None and max_rounds > 0: cmd.extend(["--max-rounds", str(max_rounds)]) # 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞 main_log_path = os.path.join(sim_dir, "simulation.log") main_log_file = open(main_log_path, 'w', encoding='utf-8') # 设置子进程环境变量,确保 Windows 上使用 UTF-8 编码 # 这可以修复第三方库(如 OASIS)读取文件时未指定编码的问题 env = os.environ.copy() env['PYTHONUTF8'] = '1' # Python 3.7+ 支持,让所有 open() 默认使用 UTF-8 env['PYTHONIOENCODING'] = 'utf-8' # 确保 stdout/stderr 使用 UTF-8 # 设置工作目录为模拟目录(数据库等文件会生成在此) # 使用 start_new_session=True 创建新的进程组,确保可以通过 os.killpg 终止所有子进程 process = subprocess.Popen( cmd, cwd=sim_dir, stdout=main_log_file, stderr=subprocess.STDOUT, # stderr 也写入同一个文件 text=True, encoding='utf-8', # 显式指定编码 bufsize=1, env=env, # 传递带有 UTF-8 设置的环境变量 start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程 ) # 保存文件句柄以便后续关闭 cls._stdout_files[simulation_id] = main_log_file cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr state.process_pid = process.pid state.runner_status = RunnerStatus.RUNNING cls._processes[simulation_id] = process cls._save_run_state(state) # 启动监控线程 monitor_thread = threading.Thread( target=cls._monitor_simulation, args=(simulation_id,), daemon=True ) monitor_thread.start() cls._monitor_threads[simulation_id] = monitor_thread logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}") except Exception as e: state.runner_status = RunnerStatus.FAILED 
state.error = str(e) cls._save_run_state(state) raise return state @classmethod def _monitor_simulation(cls, simulation_id: str): """监控模拟进程,解析动作日志""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) # 新的日志结构:分平台的动作日志 twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") process = cls._processes.get(simulation_id) state = cls.get_run_state(simulation_id) if not process or not state: return twitter_position = 0 reddit_position = 0 try: while process.poll() is None: # 进程仍在运行 # 读取 Twitter 动作日志 if os.path.exists(twitter_actions_log): twitter_position = cls._read_action_log( twitter_actions_log, twitter_position, state, "twitter" ) # 读取 Reddit 动作日志 if os.path.exists(reddit_actions_log): reddit_position = cls._read_action_log( reddit_actions_log, reddit_position, state, "reddit" ) # 更新状态 cls._save_run_state(state) time.sleep(2) # 进程结束后,最后读取一次日志 if os.path.exists(twitter_actions_log): cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter") if os.path.exists(reddit_actions_log): cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit") # 进程结束 exit_code = process.returncode if exit_code == 0: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() logger.info(f"模拟完成: {simulation_id}") else: state.runner_status = RunnerStatus.FAILED # 从主日志文件读取错误信息 main_log_path = os.path.join(sim_dir, "simulation.log") error_info = "" try: if os.path.exists(main_log_path): with open(main_log_path, 'r', encoding='utf-8') as f: error_info = f.read()[-2000:] # 取最后2000字符 except Exception: pass state.error = f"进程退出码: {exit_code}, 错误: {error_info}" logger.error(f"模拟失败: {simulation_id}, error={state.error}") state.twitter_running = False state.reddit_running = False cls._save_run_state(state) except Exception as e: logger.error(f"监控线程异常: {simulation_id}, error={str(e)}") state.runner_status = RunnerStatus.FAILED state.error = 
str(e) cls._save_run_state(state) finally: # 停止图谱记忆更新器 if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}") except Exception as e: logger.error(f"停止图谱记忆更新器失败: {e}") cls._graph_memory_enabled.pop(simulation_id, None) # 清理进程资源 cls._processes.pop(simulation_id, None) cls._action_queues.pop(simulation_id, None) # 关闭日志文件句柄 if simulation_id in cls._stdout_files: try: cls._stdout_files[simulation_id].close() except Exception: pass cls._stdout_files.pop(simulation_id, None) if simulation_id in cls._stderr_files and cls._stderr_files[simulation_id]: try: cls._stderr_files[simulation_id].close() except Exception: pass cls._stderr_files.pop(simulation_id, None) @classmethod def _read_action_log( cls, log_path: str, position: int, state: SimulationRunState, platform: str ) -> int: """ 读取动作日志文件 Args: log_path: 日志文件路径 position: 上次读取位置 state: 运行状态对象 platform: 平台名称 (twitter/reddit) Returns: 新的读取位置 """ # 检查是否启用了图谱记忆更新 graph_memory_enabled = cls._graph_memory_enabled.get(state.simulation_id, False) graph_updater = None if graph_memory_enabled: graph_updater = ZepGraphMemoryManager.get_updater(state.simulation_id) try: with open(log_path, 'r', encoding='utf-8') as f: f.seek(position) for line in f: line = line.strip() if line: try: action_data = json.loads(line) # 处理事件类型的条目 if "event_type" in action_data: event_type = action_data.get("event_type") # 检测 simulation_end 事件,标记平台已完成 if event_type == "simulation_end": if platform == "twitter": state.twitter_completed = True state.twitter_running = False logger.info(f"Twitter 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") elif platform == "reddit": state.reddit_completed = True state.reddit_running = False logger.info(f"Reddit 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, 
total_actions={action_data.get('total_actions')}") # 检查是否所有启用的平台都已完成 # 如果只运行了一个平台,只检查那个平台 # 如果运行了两个平台,需要两个都完成 all_completed = cls._check_all_platforms_completed(state) if all_completed: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() logger.info(f"所有平台模拟已完成: {state.simulation_id}") # 更新轮次信息(从 round_end 事件) elif event_type == "round_end": round_num = action_data.get("round", 0) simulated_hours = action_data.get("simulated_hours", 0) # 更新各平台独立的轮次和时间 if platform == "twitter": if round_num > state.twitter_current_round: state.twitter_current_round = round_num state.twitter_simulated_hours = simulated_hours elif platform == "reddit": if round_num > state.reddit_current_round: state.reddit_current_round = round_num state.reddit_simulated_hours = simulated_hours # 总体轮次取两个平台的最大值 if round_num > state.current_round: state.current_round = round_num # 总体时间取两个平台的最大值 state.simulated_hours = max(state.twitter_simulated_hours, state.reddit_simulated_hours) continue action = AgentAction( round_num=action_data.get("round", 0), timestamp=action_data.get("timestamp", datetime.now().isoformat()), platform=platform, agent_id=action_data.get("agent_id", 0), agent_name=action_data.get("agent_name", ""), action_type=action_data.get("action_type", ""), action_args=action_data.get("action_args", {}), result=action_data.get("result"), success=action_data.get("success", True), ) state.add_action(action) # 更新轮次 if action.round_num and action.round_num > state.current_round: state.current_round = action.round_num # 如果启用了图谱记忆更新,将活动发送到Zep if graph_updater: graph_updater.add_activity_from_dict(action_data, platform) except json.JSONDecodeError: pass return f.tell() except Exception as e: logger.warning(f"读取动作日志失败: {log_path}, error={e}") return position @classmethod def _check_all_platforms_completed(cls, state: SimulationRunState) -> bool: """ 检查所有启用的平台是否都已完成模拟 通过检查对应的 actions.jsonl 文件是否存在来判断平台是否被启用 Returns: True 如果所有启用的平台都已完成 """ sim_dir = 
os.path.join(cls.RUN_STATE_DIR, state.simulation_id) twitter_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_log = os.path.join(sim_dir, "reddit", "actions.jsonl") # 检查哪些平台被启用(通过文件是否存在判断) twitter_enabled = os.path.exists(twitter_log) reddit_enabled = os.path.exists(reddit_log) # 如果平台被启用但未完成,则返回 False if twitter_enabled and not state.twitter_completed: return False if reddit_enabled and not state.reddit_completed: return False # 至少有一个平台被启用且已完成 return twitter_enabled or reddit_enabled @classmethod def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeout: int = 10): """ 跨平台终止进程及其子进程 Args: process: 要终止的进程 simulation_id: 模拟ID(用于日志) timeout: 等待进程退出的超时时间(秒) """ if IS_WINDOWS: # Windows: 使用 taskkill 命令终止进程树 # /F = 强制终止, /T = 终止进程树(包括子进程) logger.info(f"终止进程树 (Windows): simulation={simulation_id}, pid={process.pid}") try: # 先尝试优雅终止 subprocess.run( ['taskkill', '/PID', str(process.pid), '/T'], capture_output=True, timeout=5 ) try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: # 强制终止 logger.warning(f"进程未响应,强制终止: {simulation_id}") subprocess.run( ['taskkill', '/F', '/PID', str(process.pid), '/T'], capture_output=True, timeout=5 ) process.wait(timeout=5) except Exception as e: logger.warning(f"taskkill 失败,尝试 terminate: {e}") process.terminate() try: process.wait(timeout=5) except subprocess.TimeoutExpired: process.kill() else: # Unix: 使用进程组终止 # 由于使用了 start_new_session=True,进程组 ID 等于主进程 PID pgid = os.getpgid(process.pid) logger.info(f"终止进程组 (Unix): simulation={simulation_id}, pgid={pgid}") # 先发送 SIGTERM 给整个进程组 os.killpg(pgid, signal.SIGTERM) try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: # 如果超时后还没结束,强制发送 SIGKILL logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}") os.killpg(pgid, signal.SIGKILL) process.wait(timeout=5) @classmethod def stop_simulation(cls, simulation_id: str) -> SimulationRunState: """停止模拟""" state = cls.get_run_state(simulation_id) if not state: raise ValueError(f"模拟不存在: 
{simulation_id}") if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]: raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}") state.runner_status = RunnerStatus.STOPPING cls._save_run_state(state) # 终止进程 process = cls._processes.get(simulation_id) if process and process.poll() is None: try: cls._terminate_process(process, simulation_id) except ProcessLookupError: # 进程已经不存在 pass except Exception as e: logger.error(f"终止进程组失败: {simulation_id}, error={e}") # 回退到直接终止进程 try: process.terminate() process.wait(timeout=5) except Exception: process.kill() state.runner_status = RunnerStatus.STOPPED state.twitter_running = False state.reddit_running = False state.completed_at = datetime.now().isoformat() cls._save_run_state(state) # 停止图谱记忆更新器 if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}") except Exception as e: logger.error(f"停止图谱记忆更新器失败: {e}") cls._graph_memory_enabled.pop(simulation_id, None) logger.info(f"模拟已停止: {simulation_id}") return state @classmethod def _read_actions_from_file( cls, file_path: str, default_platform: Optional[str] = None, platform_filter: Optional[str] = None, agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: """ 从单个动作文件中读取动作 Args: file_path: 动作日志文件路径 default_platform: 默认平台(当动作记录中没有 platform 字段时使用) platform_filter: 过滤平台 agent_id: 过滤 Agent ID round_num: 过滤轮次 """ if not os.path.exists(file_path): return [] actions = [] with open(file_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line: continue try: data = json.loads(line) # 跳过非动作记录(如 simulation_start, round_start, round_end 等事件) if "event_type" in data: continue # 跳过没有 agent_id 的记录(非 Agent 动作) if "agent_id" not in data: continue # 获取平台:优先使用记录中的 platform,否则使用默认平台 record_platform = data.get("platform") or default_platform or "" # 过滤 if platform_filter and record_platform != 
platform_filter: continue if agent_id is not None and data.get("agent_id") != agent_id: continue if round_num is not None and data.get("round") != round_num: continue actions.append(AgentAction( round_num=data.get("round", 0), timestamp=data.get("timestamp", ""), platform=record_platform, agent_id=data.get("agent_id", 0), agent_name=data.get("agent_name", ""), action_type=data.get("action_type", ""), action_args=data.get("action_args", {}), result=data.get("result"), success=data.get("success", True), )) except json.JSONDecodeError: continue return actions @classmethod def get_all_actions( cls, simulation_id: str, platform: Optional[str] = None, agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: """ 获取所有平台的完整动作历史(无分页限制) Args: simulation_id: 模拟ID platform: 过滤平台(twitter/reddit) agent_id: 过滤Agent round_num: 过滤轮次 Returns: 完整的动作列表(按时间戳排序,新的在前) """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) actions = [] # 读取 Twitter 动作文件(根据文件路径自动设置 platform 为 twitter) twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") if not platform or platform == "twitter": actions.extend(cls._read_actions_from_file( twitter_actions_log, default_platform="twitter", # 自动填充 platform 字段 platform_filter=platform, agent_id=agent_id, round_num=round_num )) # 读取 Reddit 动作文件(根据文件路径自动设置 platform 为 reddit) reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") if not platform or platform == "reddit": actions.extend(cls._read_actions_from_file( reddit_actions_log, default_platform="reddit", # 自动填充 platform 字段 platform_filter=platform, agent_id=agent_id, round_num=round_num )) # 如果分平台文件不存在,尝试读取旧的单一文件格式 if not actions: actions_log = os.path.join(sim_dir, "actions.jsonl") actions = cls._read_actions_from_file( actions_log, default_platform=None, # 旧格式文件中应该有 platform 字段 platform_filter=platform, agent_id=agent_id, round_num=round_num ) # 按时间戳排序(新的在前) actions.sort(key=lambda x: x.timestamp, reverse=True) return actions @classmethod 
@classmethod
def get_timeline(
    cls,
    simulation_id: str,
    start_round: int = 0,
    end_round: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Build a per-round summary of the simulation's actions.

    Args:
        simulation_id: Simulation ID.
        start_round: First round to include (inclusive).
        end_round: Last round to include (inclusive); None means no upper bound.

    Returns:
        One dict per round, ordered by round number, with per-platform action
        counts, the active agent ids, a count per action type and the
        earliest / latest action timestamp seen in that round.
    """
    # Use the unpaginated history so long simulations are not silently
    # truncated to the newest 10000 actions.
    actions = cls.get_all_actions(simulation_id)

    rounds: Dict[int, Dict[str, Any]] = {}

    for action in actions:
        round_num = action.round_num

        if round_num < start_round:
            continue
        if end_round is not None and round_num > end_round:
            continue

        summary = rounds.get(round_num)
        if summary is None:
            summary = rounds[round_num] = {
                "round_num": round_num,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "active_agents": set(),
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        # Anything that is not twitter is counted as reddit (existing behavior).
        if action.platform == "twitter":
            summary["twitter_actions"] += 1
        else:
            summary["reddit_actions"] += 1

        summary["active_agents"].add(action.agent_id)
        summary["action_types"][action.action_type] = (
            summary["action_types"].get(action.action_type, 0) + 1
        )

        # The history is sorted newest-first, so "first seen" is not the
        # earliest action; track the real min/max instead of relying on order.
        if action.timestamp < summary["first_action_time"]:
            summary["first_action_time"] = action.timestamp
        if action.timestamp > summary["last_action_time"]:
            summary["last_action_time"] = action.timestamp

    result = []
    for round_num in sorted(rounds.keys()):
        summary = rounds[round_num]
        result.append({
            "round_num": round_num,
            "twitter_actions": summary["twitter_actions"],
            "reddit_actions": summary["reddit_actions"],
            "total_actions": summary["twitter_actions"] + summary["reddit_actions"],
            "active_agents_count": len(summary["active_agents"]),
            "active_agents": list(summary["active_agents"]),
            "action_types": summary["action_types"],
            "first_action_time": summary["first_action_time"],
            "last_action_time": summary["last_action_time"],
        })

    return result
@classmethod
def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]:
    """
    Aggregate per-agent statistics over the simulation's action history.

    Args:
        simulation_id: Simulation ID.

    Returns:
        One dict per agent (id, name, total and per-platform action counts,
        a count per action type, earliest / latest action timestamp),
        sorted by total action count in descending order.
    """
    # Use the unpaginated history so agents are not under-counted once a
    # simulation exceeds 10000 actions.
    actions = cls.get_all_actions(simulation_id)

    agent_stats: Dict[int, Dict[str, Any]] = {}

    for action in actions:
        stats = agent_stats.get(action.agent_id)
        if stats is None:
            stats = agent_stats[action.agent_id] = {
                "agent_id": action.agent_id,
                "agent_name": action.agent_name,
                "total_actions": 0,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        stats["total_actions"] += 1

        # Anything that is not twitter is counted as reddit (existing behavior).
        if action.platform == "twitter":
            stats["twitter_actions"] += 1
        else:
            stats["reddit_actions"] += 1

        stats["action_types"][action.action_type] = (
            stats["action_types"].get(action.action_type, 0) + 1
        )

        # The history is sorted newest-first; compute the real min/max rather
        # than assuming the first record seen is the earliest one.
        if action.timestamp < stats["first_action_time"]:
            stats["first_action_time"] = action.timestamp
        if action.timestamp > stats["last_action_time"]:
            stats["last_action_time"] = action.timestamp

    # Most active agents first.
    return sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
Exception as e: errors.append(f"删除 {filename} 失败: {str(e)}") # 清理平台目录中的动作日志 for dir_name in dirs_to_clean: dir_path = os.path.join(sim_dir, dir_name) if os.path.exists(dir_path): actions_file = os.path.join(dir_path, "actions.jsonl") if os.path.exists(actions_file): try: os.remove(actions_file) cleaned_files.append(f"{dir_name}/actions.jsonl") except Exception as e: errors.append(f"删除 {dir_name}/actions.jsonl 失败: {str(e)}") # 清理内存中的运行状态 if simulation_id in cls._run_states: del cls._run_states[simulation_id] logger.info(f"清理模拟日志完成: {simulation_id}, 删除文件: {cleaned_files}") return { "success": len(errors) == 0, "cleaned_files": cleaned_files, "errors": errors if errors else None } # 防止重复清理的标志 _cleanup_done = False @classmethod def cleanup_all_simulations(cls): """ 清理所有运行中的模拟进程 在服务器关闭时调用,确保所有子进程被终止 """ # 防止重复清理 if cls._cleanup_done: return cls._cleanup_done = True # 检查是否有内容需要清理(避免空进程的进程打印无用日志) has_processes = bool(cls._processes) has_updaters = bool(cls._graph_memory_enabled) if not has_processes and not has_updaters: return # 没有需要清理的内容,静默返回 logger.info("正在清理所有模拟进程...") # 首先停止所有图谱记忆更新器(stop_all 内部会打印日志) try: ZepGraphMemoryManager.stop_all() except Exception as e: logger.error(f"停止图谱记忆更新器失败: {e}") cls._graph_memory_enabled.clear() # 复制字典以避免在迭代时修改 processes = list(cls._processes.items()) for simulation_id, process in processes: try: if process.poll() is None: # 进程仍在运行 logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}") try: # 使用跨平台的进程终止方法 cls._terminate_process(process, simulation_id, timeout=5) except (ProcessLookupError, OSError): # 进程可能已经不存在,尝试直接终止 try: process.terminate() process.wait(timeout=3) except Exception: process.kill() # 更新 run_state.json state = cls.get_run_state(simulation_id) if state: state.runner_status = RunnerStatus.STOPPED state.twitter_running = False state.reddit_running = False state.completed_at = datetime.now().isoformat() state.error = "服务器关闭,模拟被终止" cls._save_run_state(state) # 同时更新 state.json,将状态设为 stopped try: sim_dir = 
os.path.join(cls.RUN_STATE_DIR, simulation_id) state_file = os.path.join(sim_dir, "state.json") logger.info(f"尝试更新 state.json: {state_file}") if os.path.exists(state_file): with open(state_file, 'r', encoding='utf-8') as f: state_data = json.load(f) state_data['status'] = 'stopped' state_data['updated_at'] = datetime.now().isoformat() with open(state_file, 'w', encoding='utf-8') as f: json.dump(state_data, f, indent=2, ensure_ascii=False) logger.info(f"已更新 state.json 状态为 stopped: {simulation_id}") else: logger.warning(f"state.json 不存在: {state_file}") except Exception as state_err: logger.warning(f"更新 state.json 失败: {simulation_id}, error={state_err}") except Exception as e: logger.error(f"清理进程失败: {simulation_id}, error={e}") # 清理文件句柄 for simulation_id, file_handle in list(cls._stdout_files.items()): try: if file_handle: file_handle.close() except Exception: pass cls._stdout_files.clear() for simulation_id, file_handle in list(cls._stderr_files.items()): try: if file_handle: file_handle.close() except Exception: pass cls._stderr_files.clear() # 清理内存中的状态 cls._processes.clear() cls._action_queues.clear() logger.info("模拟进程清理完成") @classmethod def register_cleanup(cls): """ 注册清理函数 在 Flask 应用启动时调用,确保服务器关闭时清理所有模拟进程 """ global _cleanup_registered if _cleanup_registered: return # Flask debug 模式下,只在 reloader 子进程中注册清理(实际运行应用的进程) # WERKZEUG_RUN_MAIN=true 表示是 reloader 子进程 # 如果不是 debug 模式,则没有这个环境变量,也需要注册 is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' is_debug_mode = os.environ.get('FLASK_DEBUG') == '1' or os.environ.get('WERKZEUG_RUN_MAIN') is not None # 在 debug 模式下,只在 reloader 子进程中注册;非 debug 模式下始终注册 if is_debug_mode and not is_reloader_process: _cleanup_registered = True # 标记已注册,防止子进程再次尝试 return # 保存原有的信号处理器 original_sigint = signal.getsignal(signal.SIGINT) original_sigterm = signal.getsignal(signal.SIGTERM) # SIGHUP 只在 Unix 系统存在(macOS/Linux),Windows 没有 original_sighup = None has_sighup = hasattr(signal, 'SIGHUP') if has_sighup: original_sighup = 
@classmethod
def get_running_simulations(cls) -> List[str]:
    """
    List the IDs of simulations whose process is still running.

    Returns:
        Simulation IDs from ``cls._processes`` whose process ``poll()``
        returns None, in the registry's iteration order.
    """
    # poll() returning None means the process has not exited yet.
    return [
        simulation_id
        for simulation_id, proc in cls._processes.items()
        if proc.poll() is None
    ]
@classmethod
def interview_agents_batch(
    cls,
    simulation_id: str,
    interviews: List[Dict[str, Any]],
    platform: str = None,
    timeout: float = 120.0
) -> Dict[str, Any]:
    """
    Interview several agents with one batch IPC command.

    Args:
        simulation_id: Simulation ID.
        interviews: Interview items, each a dict such as
            {"agent_id": int, "prompt": str, "platform": str (optional)}.
        platform: Default platform passed through to the IPC client
            (optional; per the original notes an item's own platform overrides it).
        timeout: Seconds passed to the IPC client as its timeout.

    Returns:
        A dict with ``success``, ``interviews_count``, ``timestamp`` and either
        ``result`` (on success) or ``error`` (otherwise).

    Raises:
        ValueError: The simulation directory is missing or the environment
            is not alive.
        TimeoutError: Presumably raised by the IPC client when no response
            arrives in time (not visible in this block - confirm there).
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    client = SimulationIPCClient(sim_dir)
    env_alive = client.check_env_alive()
    if not env_alive:
        raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")

    logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}")

    reply = client.send_batch_interview(
        interviews=interviews,
        platform=platform,
        timeout=timeout
    )

    # Shared fields first, then the branch-specific payload; key order matches
    # the previous literal dicts (success, interviews_count, result|error, timestamp).
    outcome = {
        "success": reply.status.value == "completed",
        "interviews_count": len(interviews),
    }
    if outcome["success"]:
        outcome["result"] = reply.result
    else:
        outcome["error"] = reply.error
    outcome["timestamp"] = reply.timestamp
    return outcome
agent_config.get("agent_id") if agent_id is not None: interviews.append({ "agent_id": agent_id, "prompt": prompt }) logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}") return cls.interview_agents_batch( simulation_id=simulation_id, interviews=interviews, platform=platform, timeout=timeout ) @classmethod def close_simulation_env( cls, simulation_id: str, timeout: float = 30.0 ) -> Dict[str, Any]: """ 关闭模拟环境(而不是停止模拟进程) 向模拟发送关闭环境命令,使其优雅退出等待命令模式 Args: simulation_id: 模拟ID timeout: 超时时间(秒) Returns: 操作结果字典 """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): raise ValueError(f"模拟不存在: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): return { "success": True, "message": "环境已经关闭" } logger.info(f"发送关闭环境命令: simulation_id={simulation_id}") try: response = ipc_client.send_close_env(timeout=timeout) return { "success": response.status.value == "completed", "message": "环境关闭命令已发送", "result": response.result, "timestamp": response.timestamp } except TimeoutError: # 超时可能是因为环境正在关闭 return { "success": True, "message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)" } @classmethod def _get_interview_history_from_db( cls, db_path: str, platform_name: str, agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: """从单个数据库获取Interview历史""" import sqlite3 if not os.path.exists(db_path): return [] results = [] try: conn = sqlite3.connect(db_path) cursor = conn.cursor() if agent_id is not None: cursor.execute(""" SELECT user_id, info, created_at FROM trace WHERE action = 'interview' AND user_id = ? ORDER BY created_at DESC LIMIT ? """, (agent_id, limit)) else: cursor.execute(""" SELECT user_id, info, created_at FROM trace WHERE action = 'interview' ORDER BY created_at DESC LIMIT ? 
""", (limit,)) for user_id, info_json, created_at in cursor.fetchall(): try: info = json.loads(info_json) if info_json else {} except json.JSONDecodeError: info = {"raw": info_json} results.append({ "agent_id": user_id, "response": info.get("response", info), "prompt": info.get("prompt", ""), "timestamp": created_at, "platform": platform_name }) conn.close() except Exception as e: logger.error(f"读取Interview历史失败 ({platform_name}): {e}") return results @classmethod def get_interview_history( cls, simulation_id: str, platform: str = None, agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: """ 获取Interview历史记录(从数据库读取) Args: simulation_id: 模拟ID platform: 平台类型(reddit/twitter/None) - "reddit": 只获取Reddit平台的历史 - "twitter": 只获取Twitter平台的历史 - None: 获取两个平台的所有历史 agent_id: 指定Agent ID(可选,只获取该Agent的历史) limit: 每个平台返回数量限制 Returns: Interview历史记录列表 """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) results = [] # 确定要查询的平台 if platform in ("reddit", "twitter"): platforms = [platform] else: # 不指定platform时,查询两个平台 platforms = ["twitter", "reddit"] for p in platforms: db_path = os.path.join(sim_dir, f"{p}_simulation.db") platform_results = cls._get_interview_history_from_db( db_path=db_path, platform_name=p, agent_id=agent_id, limit=limit ) results.extend(platform_results) # 按时间降序排序 results.sort(key=lambda x: x.get("timestamp", ""), reverse=True) # 如果查询了多个平台,限制总数 if len(platforms) > 1 and len(results) > limit: results = results[:limit] return results ================================================ FILE: backend/app/services/text_processor.py ================================================ """ 文本处理服务 """ from typing import List, Optional from ..utils.file_parser import FileParser, split_text_into_chunks class TextProcessor: """文本处理器""" @staticmethod def extract_from_files(file_paths: List[str]) -> str: """从多个文件提取文本""" return FileParser.extract_from_multiple(file_paths) @staticmethod def split_text( text: str, chunk_size: int = 500, overlap: int = 50 ) -> 
@staticmethod
def preprocess_text(text: str) -> str:
    """
    Normalize whitespace in a text.

    - Converts ``\\r\\n`` and ``\\r`` line endings to ``\\n``.
    - Strips leading/trailing whitespace from every line.
    - Collapses runs of three or more newlines to exactly two.
    - Strips leading/trailing whitespace from the whole text.

    Args:
        text: Raw text.

    Returns:
        The cleaned text.
    """
    import re

    # Normalize line endings.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Strip each line first: lines holding only spaces/tabs become empty, so
    # the collapse below also catches "blank" lines that contained whitespace.
    text = '\n'.join(line.strip() for line in text.split('\n'))

    # Collapse consecutive blank lines (keep at most two newlines in a row).
    text = re.sub(r'\n{3,}', '\n\n', text)

    return text.strip()
def __init__(self, api_key: Optional[str] = None):
    """
    Create the reader and its Zep client.

    Args:
        api_key: Zep API key; when empty or None, ``Config.ZEP_API_KEY`` is used.

    Raises:
        ValueError: No API key is available from either source.
    """
    # An explicit key wins; otherwise fall back to the configured one.
    self.api_key = api_key if api_key else Config.ZEP_API_KEY

    if self.api_key:
        self.client = Zep(api_key=self.api_key)
        return

    raise ValueError("ZEP_API_KEY 未配置")
def filter_defined_entities(
    self,
    graph_id: str,
    defined_entity_types: Optional[List[str]] = None,
    enrich_with_edges: bool = True
) -> FilteredEntities:
    """
    Select the graph nodes that carry a predefined entity type.

    Filtering rules:
    - A node whose labels are only "Entity" / "Node" is skipped.
    - A node with any other label is kept; if ``defined_entity_types`` is
      given, at least one of the node's custom labels must be in that list.

    Args:
        graph_id: Graph ID.
        defined_entity_types: Optional allow-list of entity types.
        enrich_with_edges: Whether to attach each entity's related edges and
            the basic info of the nodes on the other end of those edges.

    Returns:
        FilteredEntities holding the kept entities, the set of entity types
        found, the total node count and the kept node count.
    """
    logger.info(f"开始筛选图谱 {graph_id} 的实体...")

    all_nodes = self.get_all_nodes(graph_id)
    total_count = len(all_nodes)

    # Edges are only fetched when they are going to be attached.
    all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []

    node_map = {n["uuid"]: n for n in all_nodes}

    # Index the edges by endpoint once, so each entity looks up its own edges
    # instead of rescanning the full edge list (was O(nodes * edges)).
    # Per-node lists keep the original edge order.
    edges_by_node: Dict[str, List[Dict[str, Any]]] = {}
    for edge in all_edges:
        source_uuid = edge["source_node_uuid"]
        target_uuid = edge["target_node_uuid"]
        edges_by_node.setdefault(source_uuid, []).append({
            "direction": "outgoing",
            "edge_name": edge["name"],
            "fact": edge["fact"],
            "target_node_uuid": target_uuid,
        })
        # A self-loop is reported once, as outgoing (same as before).
        if target_uuid != source_uuid:
            edges_by_node.setdefault(target_uuid, []).append({
                "direction": "incoming",
                "edge_name": edge["name"],
                "fact": edge["fact"],
                "source_node_uuid": source_uuid,
            })

    filtered_entities = []
    entity_types_found = set()

    for node in all_nodes:
        labels = node.get("labels", [])

        # Labels other than the default "Entity" / "Node" mark a typed entity.
        custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
        if not custom_labels:
            continue

        if defined_entity_types:
            matching_labels = [l for l in custom_labels if l in defined_entity_types]
            if not matching_labels:
                continue
            entity_type = matching_labels[0]
        else:
            entity_type = custom_labels[0]

        entity_types_found.add(entity_type)

        entity = EntityNode(
            uuid=node["uuid"],
            name=node["name"],
            labels=labels,
            summary=node["summary"],
            attributes=node["attributes"],
        )

        if enrich_with_edges:
            # Copy the dicts so entities never share mutable edge records.
            related_edges = [dict(e) for e in edges_by_node.get(node["uuid"], [])]
            entity.related_edges = related_edges

            # dict keys give de-duplication with a deterministic
            # (first-seen) order, unlike iterating a set of strings.
            related_node_uuids: Dict[str, None] = {}
            for related_edge in related_edges:
                if related_edge["direction"] == "outgoing":
                    related_node_uuids[related_edge["target_node_uuid"]] = None
                else:
                    related_node_uuids[related_edge["source_node_uuid"]] = None

            related_nodes = []
            for related_uuid in related_node_uuids:
                if related_uuid in node_map:
                    related_node = node_map[related_uuid]
                    related_nodes.append({
                        "uuid": related_node["uuid"],
                        "name": related_node["name"],
                        "labels": related_node["labels"],
                        "summary": related_node.get("summary", ""),
                    })
            entity.related_nodes = related_nodes

        filtered_entities.append(entity)

    logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                f"实体类型: {entity_types_found}")

    return FilteredEntities(
        entities=filtered_entities,
        entity_types=entity_types_found,
        total_count=total_count,
        filtered_count=len(filtered_entities),
    )
def get_entities_by_type(
    self,
    graph_id: str,
    entity_type: str,
    enrich_with_edges: bool = True
) -> List[EntityNode]:
    """
    Return every entity of a single type.

    Thin wrapper over ``filter_defined_entities`` with a one-element type list.

    Args:
        graph_id: Graph ID.
        entity_type: Entity type to keep (e.g. "Student", "PublicFigure").
        enrich_with_edges: Whether related edge information is attached.

    Returns:
        The ``entities`` list of the filter result.
    """
    wanted_types = [entity_type]
    return self.filter_defined_entities(
        graph_id=graph_id,
        defined_entity_types=wanted_types,
        enrich_with_edges=enrich_with_edges
    ).entities
@dataclass
class AgentActivity:
    """One simulated agent action that can be rendered as an episode sentence.

    Attributes:
        platform: Platform the action happened on (twitter / reddit).
        agent_id: Numeric id of the acting agent.
        agent_name: Display name of the acting agent.
        action_type: Action identifier such as CREATE_POST or LIKE_POST.
        action_args: Action payload (post text, author names, ...). ``None``
            is tolerated and treated as an empty payload.
        round_num: Simulation round in which the action occurred.
        timestamp: Timestamp string recorded for the action.
    """
    platform: str  # twitter / reddit
    agent_id: int
    agent_name: str
    action_type: str  # CREATE_POST, LIKE_POST, etc.
    action_args: Dict[str, Any]
    round_num: int
    timestamp: str

    # Action type -> name of the method that renders its description.
    # (Un-annotated on purpose so the dataclass does not treat it as a field.)
    _DESCRIBERS = {
        "CREATE_POST": "_describe_create_post",
        "LIKE_POST": "_describe_like_post",
        "DISLIKE_POST": "_describe_dislike_post",
        "REPOST": "_describe_repost",
        "QUOTE_POST": "_describe_quote_post",
        "FOLLOW": "_describe_follow",
        "CREATE_COMMENT": "_describe_create_comment",
        "LIKE_COMMENT": "_describe_like_comment",
        "DISLIKE_COMMENT": "_describe_dislike_comment",
        "SEARCH_POSTS": "_describe_search",
        "SEARCH_USER": "_describe_search_user",
        "MUTE": "_describe_mute",
    }

    def to_episode_text(self) -> str:
        """
        Render the activity as ``"<agent name>: <description>"``.

        The description is plain natural language with no simulation-specific
        prefix; unknown action types fall back to a generic description.
        """
        method_name = self._DESCRIBERS.get(self.action_type, "_describe_generic")
        description = getattr(self, method_name)()
        return f"{self.agent_name}: {description}"

    def _arg(self, *keys: str) -> Any:
        """Return the first truthy payload value among ``keys``, else ``""``.

        A missing payload (``action_args`` is ``None``, e.g. a JSON ``null``)
        is handled like an empty one instead of raising AttributeError.
        """
        args = self.action_args or {}
        for key in keys:
            value = args.get(key, "")
            if value:
                return value
        return ""

    @staticmethod
    def _describe_reaction(verb: str, noun: str, content: Any, author: Any) -> str:
        """Shared wording for like / dislike / repost style actions."""
        if content and author:
            return f"{verb}{author}的{noun}:「{content}」"
        if content:
            return f"{verb}一条{noun}:「{content}」"
        if author:
            return f"{verb}{author}的一条{noun}"
        return f"{verb}一条{noun}"

    def _describe_create_post(self) -> str:
        """Posting: includes the post text when available."""
        content = self._arg("content")
        if content:
            return f"发布了一条帖子:「{content}」"
        return "发布了一条帖子"

    def _describe_like_post(self) -> str:
        """Liking a post: includes the post text and its author when known."""
        return self._describe_reaction(
            "点赞了", "帖子", self._arg("post_content"), self._arg("post_author_name")
        )

    def _describe_dislike_post(self) -> str:
        """Disliking a post: includes the post text and its author when known."""
        return self._describe_reaction(
            "踩了", "帖子", self._arg("post_content"), self._arg("post_author_name")
        )

    def _describe_repost(self) -> str:
        """Reposting: includes the original text and its author when known."""
        return self._describe_reaction(
            "转发了", "帖子", self._arg("original_content"), self._arg("original_author_name")
        )

    def _describe_quote_post(self) -> str:
        """Quoting a post: original text/author plus the quoting comment."""
        original_content = self._arg("original_content")
        original_author = self._arg("original_author_name")
        # The comment may be stored under either key; "quote_content" wins.
        quote_content = self._arg("quote_content", "content")

        if original_content and original_author:
            base = f"引用了{original_author}的帖子「{original_content}」"
        elif original_content:
            base = f"引用了一条帖子「{original_content}」"
        elif original_author:
            base = f"引用了{original_author}的一条帖子"
        else:
            base = "引用了一条帖子"

        if quote_content:
            base += f",并评论道:「{quote_content}」"
        return base

    def _describe_follow(self) -> str:
        """Following a user: includes the followed user's name when known."""
        target_user_name = self._arg("target_user_name")
        if target_user_name:
            return f"关注了用户「{target_user_name}」"
        return "关注了一个用户"

    def _describe_create_comment(self) -> str:
        """Commenting: includes the comment and the commented post when known."""
        content = self._arg("content")
        if not content:
            return "发表了评论"

        post_content = self._arg("post_content")
        post_author = self._arg("post_author_name")
        if post_content and post_author:
            return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」"
        if post_content:
            return f"在帖子「{post_content}」下评论道:「{content}」"
        if post_author:
            return f"在{post_author}的帖子下评论道:「{content}」"
        return f"评论道:「{content}」"

    def _describe_like_comment(self) -> str:
        """Liking a comment: includes the comment text and author when known."""
        return self._describe_reaction(
            "点赞了", "评论", self._arg("comment_content"), self._arg("comment_author_name")
        )

    def _describe_dislike_comment(self) -> str:
        """Disliking a comment: includes the comment text and author when known."""
        return self._describe_reaction(
            "踩了", "评论", self._arg("comment_content"), self._arg("comment_author_name")
        )

    def _describe_search(self) -> str:
        """Searching posts: includes the search term ("query" or "keyword")."""
        query = self._arg("query", "keyword")
        return f"搜索了「{query}」" if query else "进行了搜索"

    def _describe_search_user(self) -> str:
        """Searching users: includes the search term ("query" or "username")."""
        query = self._arg("query", "username")
        return f"搜索了用户「{query}」" if query else "搜索了用户"

    def _describe_mute(self) -> str:
        """Muting a user: includes the muted user's name when known."""
        target_user_name = self._arg("target_user_name")
        if target_user_name:
            return f"屏蔽了用户「{target_user_name}」"
        return "屏蔽了一个用户"

    def _describe_generic(self) -> str:
        """Fallback wording for action types without a dedicated describer."""
        return f"执行了{self.action_type}操作"
class ZepGraphMemoryUpdater:
    """
    Streams simulated agent activities into a Zep graph.

    Activities are queued by ``add_activity`` and grouped per platform by a
    background worker; every time a platform has accumulated BATCH_SIZE
    activities they are joined into one text episode and sent to Zep.
    Whatever is left when ``stop`` is called is flushed as a final batch per
    platform.

    Network sends never happen while ``_buffer_lock`` is held, so
    ``get_stats`` and producers are not blocked by slow or retried requests.
    """

    # Number of activities accumulated per platform before a batch is sent.
    BATCH_SIZE = 5

    # Platform name -> display name used in log output.
    PLATFORM_DISPLAY_NAMES = {
        'twitter': '世界1',
        'reddit': '世界2',
    }

    # Pause (seconds) after each batch sent by the worker, to avoid bursts.
    SEND_INTERVAL = 0.5

    # Retry settings for a single batch send.
    MAX_RETRIES = 3
    RETRY_DELAY = 2  # seconds; multiplied by the attempt number

    def __init__(self, graph_id: str, api_key: Optional[str] = None):
        """
        Create an updater for one graph.

        Args:
            graph_id: Zep graph ID.
            api_key: Zep API key (optional, falls back to Config.ZEP_API_KEY).

        Raises:
            ValueError: if no API key is available.
        """
        self.graph_id = graph_id
        self.api_key = api_key or Config.ZEP_API_KEY

        if not self.api_key:
            raise ValueError("ZEP_API_KEY未配置")

        self.client = Zep(api_key=self.api_key)

        # Incoming activities, consumed by the worker thread.
        self._activity_queue: Queue = Queue()

        # Per-platform buffers; each one is sent once it reaches BATCH_SIZE.
        self._platform_buffers: Dict[str, List["AgentActivity"]] = {
            'twitter': [],
            'reddit': [],
        }
        self._buffer_lock = threading.Lock()

        # Worker control.
        self._running = False
        self._worker_thread: Optional[threading.Thread] = None

        # Statistics.
        self._total_activities = 0  # activities actually queued
        self._total_sent = 0        # batches sent successfully
        self._total_items_sent = 0  # activities sent successfully
        self._failed_count = 0      # batches that failed after all retries
        self._skipped_count = 0     # activities filtered out (DO_NOTHING)

        logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")

    def _get_platform_display_name(self, platform: str) -> str:
        """Return the display name of a platform (the raw name if unmapped)."""
        return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)

    def start(self):
        """Start the background worker thread (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._worker_thread = threading.Thread(
            target=self._worker_loop,
            daemon=True,
            name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
        )
        self._worker_thread.start()
        logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}")

    def stop(self):
        """
        Stop the worker and send everything that is still pending.

        The worker is joined BEFORE the final flush: it keeps draining the
        queue after ``_running`` is cleared, and flushing while it is still
        moving items into the buffers could strand an activity that it
        dequeued just before the flush cleared the buffers.
        """
        self._running = False

        worker = self._worker_thread
        if worker and worker.is_alive() and worker is not threading.current_thread():
            worker.join(timeout=10)

        # Send whatever is left in the queue and the per-platform buffers.
        self._flush_remaining()

        logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
                    f"total_activities={self._total_activities}, "
                    f"batches_sent={self._total_sent}, "
                    f"items_sent={self._total_items_sent}, "
                    f"failed={self._failed_count}, "
                    f"skipped={self._skipped_count}")

    def add_activity(self, activity: "AgentActivity"):
        """
        Queue one agent activity for delivery.

        Every action type is queued except DO_NOTHING, which is only counted
        as skipped.

        Args:
            activity: The activity record to queue.
        """
        if activity.action_type == "DO_NOTHING":
            self._skipped_count += 1
            return

        self._activity_queue.put(activity)
        self._total_activities += 1
        logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}")

    def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
        """
        Build an activity from a parsed log entry and queue it.

        Entries carrying an "event_type" key are ignored.

        Args:
            data: Dictionary parsed from one actions.jsonl line.
            platform: Platform name (twitter/reddit).
        """
        if "event_type" in data:
            return

        activity = AgentActivity(
            platform=platform,
            agent_id=data.get("agent_id", 0),
            agent_name=data.get("agent_name", ""),
            action_type=data.get("action_type", ""),
            # A JSON null payload would otherwise reach the describers as None.
            action_args=data.get("action_args") or {},
            round_num=data.get("round", 0),
            timestamp=data.get("timestamp", datetime.now().isoformat()),
        )

        self.add_activity(activity)

    def _worker_loop(self):
        """Worker body: move queued activities into buffers and send full batches."""
        while self._running or not self._activity_queue.empty():
            try:
                try:
                    # Wake up at least once per second to re-check _running.
                    activity = self._activity_queue.get(timeout=1)
                except Empty:
                    continue

                platform = activity.platform.lower()
                batch = None
                with self._buffer_lock:
                    buffer = self._platform_buffers.setdefault(platform, [])
                    buffer.append(activity)
                    if len(buffer) >= self.BATCH_SIZE:
                        batch = buffer[:self.BATCH_SIZE]
                        self._platform_buffers[platform] = buffer[self.BATCH_SIZE:]

                # Send only after the lock is released: a send can take
                # seconds (retries) and must not block other threads.
                if batch:
                    self._send_batch_activities(batch, platform)
                    time.sleep(self.SEND_INTERVAL)

            except Exception as e:
                logger.error(f"工作循环异常: {e}")
                time.sleep(1)

    def _send_batch_activities(self, activities: List["AgentActivity"], platform: str):
        """
        Send a batch to the Zep graph as one newline-joined text episode.

        Retries up to MAX_RETRIES times with a linearly growing delay; a
        batch that still fails is counted in ``_failed_count`` and dropped.

        Args:
            activities: Activities to send.
            platform: Platform name (used for logging only).
        """
        if not activities:
            return

        combined_text = "\n".join(activity.to_episode_text() for activity in activities)

        for attempt in range(self.MAX_RETRIES):
            try:
                self.client.graph.add(
                    graph_id=self.graph_id,
                    type="text",
                    data=combined_text
                )

                self._total_sent += 1
                self._total_items_sent += len(activities)
                display_name = self._get_platform_display_name(platform)
                logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
                logger.debug(f"批量内容预览: {combined_text[:200]}...")
                return

            except Exception as e:
                if attempt < self.MAX_RETRIES - 1:
                    logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
                    time.sleep(self.RETRY_DELAY * (attempt + 1))
                else:
                    logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}")
                    self._failed_count += 1

    def _flush_remaining(self):
        """Send everything still sitting in the queue and the buffers."""
        # Move whatever is still queued into the per-platform buffers.
        while True:
            try:
                activity = self._activity_queue.get_nowait()
            except Empty:
                break
            platform = activity.platform.lower()
            with self._buffer_lock:
                self._platform_buffers.setdefault(platform, []).append(activity)

        # Detach the pending batches under the lock, send them outside of it.
        with self._buffer_lock:
            pending = {
                platform: buffer
                for platform, buffer in self._platform_buffers.items()
                if buffer
            }
            for platform in self._platform_buffers:
                self._platform_buffers[platform] = []

        # Remaining activities go out even if fewer than BATCH_SIZE.
        for platform, buffer in pending.items():
            display_name = self._get_platform_display_name(platform)
            logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
            self._send_batch_activities(buffer, platform)

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of counters, queue size and buffer sizes."""
        with self._buffer_lock:
            buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}

        return {
            "graph_id": self.graph_id,
            "batch_size": self.BATCH_SIZE,
            "total_activities": self._total_activities,  # activities queued
            "batches_sent": self._total_sent,            # batches sent successfully
            "items_sent": self._total_items_sent,        # activities sent successfully
            "failed_count": self._failed_count,          # batches that failed
            "skipped_count": self._skipped_count,        # activities skipped (DO_NOTHING)
            "queue_size": self._activity_queue.qsize(),
            "buffer_sizes": buffer_sizes,                # per-platform buffer sizes
            "running": self._running,
        }
class ZepGraphMemoryManager:
    """
    Registry of graph memory updaters, one per simulation.

    All access to the shared registry goes through ``_lock``.
    """

    _updaters: Dict[str, "ZepGraphMemoryUpdater"] = {}
    _lock = threading.Lock()

    # Latch that makes repeated stop_all() calls cheap no-ops. It is re-armed
    # whenever a new updater is created, so updaters created after an earlier
    # stop_all() can still be stopped by the next one.
    _stop_all_done = False

    @classmethod
    def create_updater(cls, simulation_id: str, graph_id: str) -> "ZepGraphMemoryUpdater":
        """
        Create and start the updater for a simulation.

        An existing updater for the same simulation is stopped and removed
        first.

        Args:
            simulation_id: Simulation ID.
            graph_id: Zep graph ID.

        Returns:
            The started ZepGraphMemoryUpdater instance.
        """
        with cls._lock:
            # Pop first so a stopped updater never stays registered, even if
            # constructing the replacement fails.
            previous = cls._updaters.pop(simulation_id, None)
            if previous is not None:
                previous.stop()

            updater = ZepGraphMemoryUpdater(graph_id)
            updater.start()
            cls._updaters[simulation_id] = updater
            cls._stop_all_done = False

            logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}")
            return updater

    @classmethod
    def get_updater(cls, simulation_id: str) -> Optional["ZepGraphMemoryUpdater"]:
        """Return the updater registered for a simulation, or None."""
        with cls._lock:
            return cls._updaters.get(simulation_id)

    @classmethod
    def stop_updater(cls, simulation_id: str):
        """Stop and unregister the updater of a simulation (no-op if absent)."""
        with cls._lock:
            updater = cls._updaters.pop(simulation_id, None)
            if updater is not None:
                updater.stop()
                logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")

    @classmethod
    def stop_all(cls):
        """Stop every registered updater; repeated calls do nothing."""
        with cls._lock:
            if cls._stop_all_done:
                return
            cls._stop_all_done = True

            if cls._updaters:
                for simulation_id, updater in list(cls._updaters.items()):
                    try:
                        updater.stop()
                    except Exception as e:
                        logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")
                cls._updaters.clear()
            logger.info("已停止所有图谱记忆更新器")

    @classmethod
    def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
        """Return ``get_stats()`` of every registered updater, keyed by simulation id."""
        # Snapshot under the lock: iterating the live dict could fail if
        # another thread registers or removes an updater meanwhile.
        with cls._lock:
            snapshot = list(cls._updaters.items())
        return {sim_id: updater.get_stats() for sim_id, updater in snapshot}
@dataclass
class SearchResult:
    """Outcome of a graph search: matched facts, edges and nodes."""
    facts: List[str]
    edges: List[Dict[str, Any]]
    nodes: List[Dict[str, Any]]
    query: str
    total_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary."""
        field_names = ("facts", "edges", "nodes", "query", "total_count")
        return {key: getattr(self, key) for key in field_names}

    def to_text(self) -> str:
        """Render the query, the hit count and the numbered facts for an LLM."""
        lines = [f"搜索查询: {self.query}", f"找到 {self.total_count} 条相关信息"]
        if self.facts:
            lines.append("\n### 相关事实:")
            lines.extend(f"{idx}. {fact}" for idx, fact in enumerate(self.facts, 1))
        return "\n".join(lines)


@dataclass
class NodeInfo:
    """A graph node: identity, labels, summary and attributes."""
    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary."""
        field_names = ("uuid", "name", "labels", "summary", "attributes")
        return {key: getattr(self, key) for key in field_names}

    def to_text(self) -> str:
        """Render name, first non-generic label and summary as text."""
        entity_type = "未知类型"
        for label in self.labels:
            # "Entity" and "Node" are generic labels, not real entity types.
            if label not in ["Entity", "Node"]:
                entity_type = label
                break
        return f"实体: {self.name} (类型: {entity_type})\n摘要: {self.summary}"


@dataclass
class EdgeInfo:
    """A graph edge (a fact between two nodes) with optional temporal data."""
    uuid: str
    name: str
    fact: str
    source_node_uuid: str
    target_node_uuid: str
    source_node_name: Optional[str] = None
    target_node_name: Optional[str] = None
    # Temporal information
    created_at: Optional[str] = None
    valid_at: Optional[str] = None
    invalid_at: Optional[str] = None
    expired_at: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary."""
        field_names = (
            "uuid", "name", "fact",
            "source_node_uuid", "target_node_uuid",
            "source_node_name", "target_node_name",
            "created_at", "valid_at", "invalid_at", "expired_at",
        )
        return {key: getattr(self, key) for key in field_names}

    def to_text(self, include_temporal: bool = False) -> str:
        """Render the relation and its fact; optionally append validity info.

        Node names are preferred; otherwise the first 8 uuid characters are used.
        """
        source = self.source_node_name if self.source_node_name else self.source_node_uuid[:8]
        target = self.target_node_name if self.target_node_name else self.target_node_uuid[:8]
        pieces = [f"关系: {source} --[{self.name}]--> {target}\n事实: {self.fact}"]

        if include_temporal:
            valid_from = self.valid_at or "未知"
            valid_until = self.invalid_at or "至今"
            pieces.append(f"\n时效: {valid_from} - {valid_until}")
            if self.expired_at:
                pieces.append(f" (已过期: {self.expired_at})")

        return "".join(pieces)

    @property
    def is_expired(self) -> bool:
        """True when an expiry timestamp is set."""
        return not (self.expired_at is None)

    @property
    def is_invalid(self) -> bool:
        """True when an invalidation timestamp is set."""
        return not (self.invalid_at is None)
@dataclass
class InsightForgeResult:
    """
    Result of an InsightForge (deep insight) retrieval.

    Holds the generated sub-queries, the retrieval results per dimension and
    the corresponding counters.
    """
    query: str
    simulation_requirement: str
    sub_queries: List[str]

    # Retrieval results per dimension
    semantic_facts: List[str] = field(default_factory=list)  # semantic search hits
    entity_insights: List[Dict[str, Any]] = field(default_factory=list)  # entity insights
    relationship_chains: List[str] = field(default_factory=list)  # relationship chains

    # Counters
    total_facts: int = 0
    total_entities: int = 0
    total_relationships: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary."""
        field_names = (
            "query", "simulation_requirement", "sub_queries",
            "semantic_facts", "entity_insights", "relationship_chains",
            "total_facts", "total_entities", "total_relationships",
        )
        return {key: getattr(self, key) for key in field_names}

    def to_text(self) -> str:
        """Render the full result as sectioned text for an LLM."""
        out = [
            "## 未来预测深度分析",
            f"分析问题: {self.query}",
            f"预测场景: {self.simulation_requirement}",
            "\n### 预测数据统计",
            f"- 相关预测事实: {self.total_facts}条",
            f"- 涉及实体: {self.total_entities}个",
            f"- 关系链: {self.total_relationships}条",
        ]

        # Sub-queries
        if self.sub_queries:
            out.append("\n### 分析的子问题")
            out += [f"{num}. {sub_query}" for num, sub_query in enumerate(self.sub_queries, 1)]

        # Semantic search facts (quoted)
        if self.semantic_facts:
            out.append("\n### 【关键事实】(请在报告中引用这些原文)")
            out += [f"{num}. \"{fact}\"" for num, fact in enumerate(self.semantic_facts, 1)]

        # Entity insights: optional summary / fact-count lines per entity
        if self.entity_insights:
            out.append("\n### 【核心实体】")
            for entity in self.entity_insights:
                name = entity.get('name', '未知')
                kind = entity.get('type', '实体')
                out.append(f"- **{name}** ({kind})")
                summary = entity.get('summary')
                if summary:
                    out.append(f" 摘要: \"{summary}\"")
                related = entity.get('related_facts')
                if related:
                    out.append(f" 相关事实: {len(related)}条")

        # Relationship chains
        if self.relationship_chains:
            out.append("\n### 【关系链】")
            out += [f"- {chain}" for chain in self.relationship_chains]

        return "\n".join(out)
@dataclass
class PanoramaResult:
    """
    Result of a Panorama (breadth) search.

    Carries every node and edge that was collected, with facts split into
    currently active ones and historical (expired / invalidated) ones.
    """
    query: str

    # All nodes
    all_nodes: List[NodeInfo] = field(default_factory=list)
    # All edges (expired ones included)
    all_edges: List[EdgeInfo] = field(default_factory=list)
    # Facts that are currently active
    active_facts: List[str] = field(default_factory=list)
    # Expired / invalidated facts (history)
    historical_facts: List[str] = field(default_factory=list)

    # Counters
    total_nodes: int = 0
    total_edges: int = 0
    active_count: int = 0
    historical_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a dictionary; nodes and edges via their to_dict()."""
        node_dicts = [node.to_dict() for node in self.all_nodes]
        edge_dicts = [edge.to_dict() for edge in self.all_edges]
        return {
            "query": self.query,
            "all_nodes": node_dicts,
            "all_edges": edge_dicts,
            "active_facts": self.active_facts,
            "historical_facts": self.historical_facts,
            "total_nodes": self.total_nodes,
            "total_edges": self.total_edges,
            "active_count": self.active_count,
            "historical_count": self.historical_count
        }

    def to_text(self) -> str:
        """Render the complete result as text (nothing is truncated)."""
        out = [
            "## 广度搜索结果(未来全景视图)",
            f"查询: {self.query}",
            "\n### 统计信息",
            f"- 总节点数: {self.total_nodes}",
            f"- 总边数: {self.total_edges}",
            f"- 当前有效事实: {self.active_count}条",
            f"- 历史/过期事实: {self.historical_count}条",
        ]

        # Both fact sections share the same numbered, quoted layout.
        fact_sections = (
            ("\n### 【当前有效事实】(模拟结果原文)", self.active_facts),
            ("\n### 【历史/过期事实】(演变过程记录)", self.historical_facts),
        )
        for heading, facts in fact_sections:
            if facts:
                out.append(heading)
                out.extend(f"{num}. \"{fact}\"" for num, fact in enumerate(facts, 1))

        # Entities: name plus first non-generic label
        if self.all_nodes:
            out.append("\n### 【涉及实体】")
            for node in self.all_nodes:
                specific = [label for label in node.labels if label not in ["Entity", "Node"]]
                entity_type = specific[0] if specific else "实体"
                out.append(f"- **{node.name}** ({entity_type})")

        return "\n".join(out)
@dataclass
class AgentInterview:
    """Interview result of a single simulated agent."""
    agent_name: str
    agent_role: str  # role type (e.g. student, teacher, media)
    agent_bio: str  # short biography
    question: str  # interview question
    response: str  # interview answer
    key_quotes: List[str] = field(default_factory=list)  # key quotes

    # Characters removed from the start of a quote: ASCII and full-width /
    # CJK punctuation plus whitespace. (Un-annotated: not dataclass fields.)
    _LEADING_JUNK = ',;:!?\uff0c\uff1b\uff1a\uff01\uff1f\u3001\u3002\n\r\t '
    _MAX_QUOTE_LEN = 150    # quotes longer than this are shortened
    _MIN_SENTENCE_CUT = 80  # earliest position for a sentence-boundary cut
    _MIN_QUOTE_LEN = 10     # shorter quotes are dropped as noise

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dictionary."""
        return {
            "agent_name": self.agent_name,
            "agent_role": self.agent_role,
            "agent_bio": self.agent_bio,
            "question": self.question,
            "response": self.response,
            "key_quotes": self.key_quotes
        }

    @classmethod
    def _clean_quote(cls, quote: Any) -> Optional[str]:
        """Normalise one key quote; return None when it should be dropped.

        Quote marks and leading punctuation are removed, quotes that contain
        a question number ("问题1" .. "问题9") are discarded, and over-long
        quotes are cut at a sentence end ("。") found within the length
        limit, or hard-truncated with "..." when there is none.
        """
        if not isinstance(quote, str):
            return None

        cleaned = quote
        for mark in ('\u201c', '\u201d', '"', '\u300c', '\u300d'):
            cleaned = cleaned.replace(mark, '')
        cleaned = cleaned.strip().lstrip(cls._LEADING_JUNK)

        # Quotes echoing a question number are leftovers of the prompt.
        if any(f'\u95ee\u9898{digit}' in cleaned for digit in '123456789'):
            return None

        if len(cleaned) > cls._MAX_QUOTE_LEN:
            # Bound the search: a sentence end beyond the limit would not
            # shorten the quote to the intended size.
            cut = cleaned.find('\u3002', cls._MIN_SENTENCE_CUT, cls._MAX_QUOTE_LEN)
            if cut != -1:
                cleaned = cleaned[:cut + 1]
            else:
                cleaned = cleaned[:cls._MAX_QUOTE_LEN - 3] + "..."

        return cleaned if len(cleaned) >= cls._MIN_QUOTE_LEN else None

    def to_text(self) -> str:
        """Render the interview (bio, question, answer, cleaned key quotes).

        The bio is shown in full. The key-quote section only appears when at
        least one quote survives cleaning.
        """
        text = f"**{self.agent_name}** ({self.agent_role})\n"
        text += f"_简介: {self.agent_bio}_\n\n"
        text += f"**Q:** {self.question}\n\n"
        text += f"**A:** {self.response}\n"

        quotes = [q for q in map(self._clean_quote, self.key_quotes or []) if q]
        if quotes:
            text += "\n**关键引言:**\n"
            text += "".join(f'> "{q}"\n' for q in quotes)
        return text


@dataclass
class InterviewResult:
    """
    Result of an interview round with several simulated agents.
    """
    interview_topic: str  # interview topic
    interview_questions: List[str]  # list of interview questions

    # Agents selected for the interview
    selected_agents: List[Dict[str, Any]] = field(default_factory=list)
    # Answers of each interviewed agent
    interviews: List[AgentInterview] = field(default_factory=list)

    # Why these agents were selected
    selection_reasoning: str = ""
    # Consolidated interview summary
    summary: str = ""

    # Counters
    total_agents: int = 0
    interviewed_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a dictionary; interviews via their to_dict()."""
        return {
            "interview_topic": self.interview_topic,
            "interview_questions": self.interview_questions,
            "selected_agents": self.selected_agents,
            "interviews": [i.to_dict() for i in self.interviews],
            "selection_reasoning": self.selection_reasoning,
            "summary": self.summary,
            "total_agents": self.total_agents,
            "interviewed_count": self.interviewed_count
        }

    def to_text(self) -> str:
        """Render the whole interview report as text for LLM use and citation."""
        text_parts = [
            "## 深度采访报告",
            f"**采访主题:** {self.interview_topic}",
            f"**采访人数:** {self.interviewed_count} / {self.total_agents} 位模拟Agent",
            "\n### 采访对象选择理由",
            self.selection_reasoning or "(自动选择)",
            "\n---",
            "\n### 采访实录",
        ]

        if self.interviews:
            for i, interview in enumerate(self.interviews, 1):
                text_parts.append(f"\n#### 采访 #{i}: {interview.agent_name}")
                text_parts.append(interview.to_text())
                text_parts.append("\n---")
        else:
            text_parts.append("(无采访记录)\n\n---")

        text_parts.append("\n### 采访摘要与核心观点")
        text_parts.append(self.summary or "(无摘要)")

        return "\n".join(text_parts)
def _call_with_retry(self, func, operation_name: str, max_retries: int = None):
    """
    Call ``func`` and retry it when it raises.

    Between attempts the method sleeps, starting at RETRY_DELAY seconds and
    doubling after every failure.

    Args:
        func: Zero-argument callable to execute.
        operation_name: Label used in log messages.
        max_retries: Number of attempts; a falsy value means MAX_RETRIES.

    Returns:
        Whatever ``func`` returns on the first successful attempt.

    Raises:
        The exception of the last attempt when every attempt failed.
    """
    max_retries = max_retries or self.MAX_RETRIES
    wait_seconds = self.RETRY_DELAY
    failure = None

    attempt = 0
    while attempt < max_retries:
        try:
            return func()
        except Exception as exc:
            failure = exc
            if attempt >= max_retries - 1:
                # Last attempt: report and fall through to the final raise.
                logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(exc)}")
            else:
                logger.warning(
                    f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(exc)[:100]}, "
                    f"{wait_seconds:.1f}秒后重试..."
                )
                time.sleep(wait_seconds)
                wait_seconds *= 2
        attempt += 1

    raise failure
def _local_search(
    self,
    graph_id: str,
    query: str,
    limit: int = 10,
    scope: str = "edges"
) -> SearchResult:
    """
    Local keyword search, used as fallback when the Zep search API fails.

    Loads all edges and/or nodes of the graph and ranks them by a simple
    score: 100 when the whole query occurs in a text, otherwise 10 per
    matching keyword. Errors are logged and whatever was collected so far
    is returned.

    Args:
        graph_id: Graph ID.
        query: Search query.
        limit: Maximum number of edges (and of nodes) to return.
        scope: "edges", "nodes" or "both".

    Returns:
        SearchResult with the best-matching edges/nodes and their facts.
    """
    logger.info(f"使用本地搜索: query={query[:30]}...")

    facts = []
    edges_result = []
    nodes_result = []

    # Split the query into keywords. ASCII, full-width and ideographic commas
    # all act as separators (the second replace used to repeat the ASCII
    # comma, so CJK-punctuated queries were never split).
    query_lower = query.lower()
    normalized = query_lower
    for separator in (',', '\uff0c', '\u3001'):
        normalized = normalized.replace(separator, ' ')
    keywords = [word for word in normalized.split() if len(word) > 1]

    def match_score(text: str) -> int:
        """Score how well ``text`` matches the query."""
        if not text:
            return 0
        text_lower = text.lower()
        # The complete query occurring verbatim beats any keyword match.
        if query_lower in text_lower:
            return 100
        return sum(10 for keyword in keywords if keyword in text_lower)

    def best_matches(items, score_of):
        """Return up to ``limit`` items with a positive score, best first."""
        scored = []
        for item in items:
            score = score_of(item)
            if score > 0:
                scored.append((score, item))
        # list.sort is stable, so equally scored items keep their order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [item for _, item in scored[:limit]]

    try:
        if scope in ["edges", "both"]:
            all_edges = self.get_all_edges(graph_id)
            for edge in best_matches(
                all_edges, lambda e: match_score(e.fact) + match_score(e.name)
            ):
                if edge.fact:
                    facts.append(edge.fact)
                edges_result.append({
                    "uuid": edge.uuid,
                    "name": edge.name,
                    "fact": edge.fact,
                    "source_node_uuid": edge.source_node_uuid,
                    "target_node_uuid": edge.target_node_uuid,
                })

        if scope in ["nodes", "both"]:
            all_nodes = self.get_all_nodes(graph_id)
            for node in best_matches(
                all_nodes, lambda n: match_score(n.name) + match_score(n.summary)
            ):
                nodes_result.append({
                    "uuid": node.uuid,
                    "name": node.name,
                    "labels": node.labels,
                    "summary": node.summary,
                })
                # A node summary also counts as a fact.
                if node.summary:
                    facts.append(f"[{node.name}]: {node.summary}")

        logger.info(f"本地搜索完成: 找到 {len(facts)} 条相关事实")

    except Exception as e:
        logger.error(f"本地搜索失败: {str(e)}")

    return SearchResult(
        facts=facts,
        edges=edges_result,
        nodes=nodes_result,
        query=query,
        total_count=len(facts)
    )
""" 获取图谱的所有节点(分页获取) Args: graph_id: 图谱ID Returns: 节点列表 """ logger.info(f"获取图谱 {graph_id} 的所有节点...") nodes = fetch_all_nodes(self.client, graph_id) result = [] for node in nodes: node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or "" result.append(NodeInfo( uuid=str(node_uuid) if node_uuid else "", name=node.name or "", labels=node.labels or [], summary=node.summary or "", attributes=node.attributes or {} )) logger.info(f"获取到 {len(result)} 个节点") return result def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]: """ 获取图谱的所有边(分页获取,包含时间信息) Args: graph_id: 图谱ID include_temporal: 是否包含时间信息(默认True) Returns: 边列表(包含created_at, valid_at, invalid_at, expired_at) """ logger.info(f"获取图谱 {graph_id} 的所有边...") edges = fetch_all_edges(self.client, graph_id) result = [] for edge in edges: edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or "" edge_info = EdgeInfo( uuid=str(edge_uuid) if edge_uuid else "", name=edge.name or "", fact=edge.fact or "", source_node_uuid=edge.source_node_uuid or "", target_node_uuid=edge.target_node_uuid or "" ) # 添加时间信息 if include_temporal: edge_info.created_at = getattr(edge, 'created_at', None) edge_info.valid_at = getattr(edge, 'valid_at', None) edge_info.invalid_at = getattr(edge, 'invalid_at', None) edge_info.expired_at = getattr(edge, 'expired_at', None) result.append(edge_info) logger.info(f"获取到 {len(result)} 条边") return result def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: """ 获取单个节点的详细信息 Args: node_uuid: 节点UUID Returns: 节点信息或None """ logger.info(f"获取节点详情: {node_uuid[:8]}...") try: node = self._call_with_retry( func=lambda: self.client.graph.node.get(uuid_=node_uuid), operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)" ) if not node: return None return NodeInfo( uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), name=node.name or "", labels=node.labels or [], summary=node.summary or "", attributes=node.attributes or {} ) except Exception 
def get_entity_summary(
    self,
    graph_id: str,
    entity_name: str
) -> Dict[str, Any]:
    """
    Collect a relationship summary for the entity with the given name.

    Runs a graph search for the name, looks the entity up among all nodes
    (case-insensitive exact name match, first hit wins) and, when found,
    gathers the edges attached to it.

    Args:
        graph_id: Graph ID.
        entity_name: Name of the entity.

    Returns:
        Dictionary with the entity info (None if not found), the related
        facts from the search, the related edges and their count.
    """
    logger.info(f"获取实体 {entity_name} 的关系摘要...")

    # Search facts related to the entity first.
    search_result = self.search_graph(
        graph_id=graph_id,
        query=entity_name,
        limit=20
    )

    # Locate the entity node by name among all nodes.
    entity_node = next(
        (
            candidate
            for candidate in self.get_all_nodes(graph_id)
            if candidate.name.lower() == entity_name.lower()
        ),
        None,
    )

    # Edges are only looked up when the entity exists in the graph.
    related_edges = self.get_node_edges(graph_id, entity_node.uuid) if entity_node else []
    entity_info = entity_node.to_dict() if entity_node else None

    return {
        "entity_name": entity_name,
        "entity_info": entity_info,
        "related_facts": search_result.facts,
        "related_edges": [edge.to_dict() for edge in related_edges],
        "total_relations": len(related_edges)
    }
def get_simulation_context(
    self,
    graph_id: str,
    simulation_requirement: str,
    limit: int = 30
) -> Dict[str, Any]:
    """
    Collect the context needed for a simulation.

    Combines a search driven by the simulation requirement, the graph
    statistics, and the list of nodes that carry at least one label other
    than the generic "Entity" / "Node" labels.

    Args:
        graph_id: Graph ID
        simulation_requirement: Description of the simulation requirement
        limit: Cap applied to the search and to the returned entity list

    Returns:
        Dict with ``simulation_requirement``, ``related_facts``,
        ``graph_statistics``, ``entities`` (at most ``limit`` items) and
        ``total_entities`` (count before the cap is applied).
    """
    logger.info(f"获取模拟上下文: {simulation_requirement[:50]}...")

    generic_labels = ["Entity", "Node"]

    # Facts related to the simulation requirement.
    found = self.search_graph(
        graph_id=graph_id,
        query=simulation_requirement,
        limit=limit
    )

    # Graph statistics.
    graph_stats = self.get_graph_statistics(graph_id)

    # Keep only nodes that have a concrete type; the first non-generic
    # label is reported as the entity type.
    typed_entities = []
    for candidate in self.get_all_nodes(graph_id):
        specific = [label for label in candidate.labels if label not in generic_labels]
        if not specific:
            continue
        typed_entities.append({
            "name": candidate.name,
            "type": specific[0],
            "summary": candidate.summary
        })

    return {
        "simulation_requirement": simulation_requirement,
        "related_facts": found.facts,
        "graph_statistics": graph_stats,
        "entities": typed_entities[:limit],
        "total_entities": len(typed_entities)
    }
整合所有结果,生成深度洞察 Args: graph_id: 图谱ID query: 用户问题 simulation_requirement: 模拟需求描述 report_context: 报告上下文(可选,用于更精准的子问题生成) max_sub_queries: 最大子问题数量 Returns: InsightForgeResult: 深度洞察检索结果 """ logger.info(f"InsightForge 深度洞察检索: {query[:50]}...") result = InsightForgeResult( query=query, simulation_requirement=simulation_requirement, sub_queries=[] ) # Step 1: 使用LLM生成子问题 sub_queries = self._generate_sub_queries( query=query, simulation_requirement=simulation_requirement, report_context=report_context, max_queries=max_sub_queries ) result.sub_queries = sub_queries logger.info(f"生成 {len(sub_queries)} 个子问题") # Step 2: 对每个子问题进行语义搜索 all_facts = [] all_edges = [] seen_facts = set() for sub_query in sub_queries: search_result = self.search_graph( graph_id=graph_id, query=sub_query, limit=15, scope="edges" ) for fact in search_result.facts: if fact not in seen_facts: all_facts.append(fact) seen_facts.add(fact) all_edges.extend(search_result.edges) # 对原始问题也进行搜索 main_search = self.search_graph( graph_id=graph_id, query=query, limit=20, scope="edges" ) for fact in main_search.facts: if fact not in seen_facts: all_facts.append(fact) seen_facts.add(fact) result.semantic_facts = all_facts result.total_facts = len(all_facts) # Step 3: 从边中提取相关实体UUID,只获取这些实体的信息(不获取全部节点) entity_uuids = set() for edge_data in all_edges: if isinstance(edge_data, dict): source_uuid = edge_data.get('source_node_uuid', '') target_uuid = edge_data.get('target_node_uuid', '') if source_uuid: entity_uuids.add(source_uuid) if target_uuid: entity_uuids.add(target_uuid) # 获取所有相关实体的详情(不限制数量,完整输出) entity_insights = [] node_map = {} # 用于后续关系链构建 for uuid in list(entity_uuids): # 处理所有实体,不截断 if not uuid: continue try: # 单独获取每个相关节点的信息 node = self.get_node_detail(uuid) if node: node_map[uuid] = node entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") # 获取该实体相关的所有事实(不截断) related_facts = [ f for f in all_facts if node.name.lower() in f.lower() ] entity_insights.append({ "uuid": node.uuid, "name": 
node.name, "type": entity_type, "summary": node.summary, "related_facts": related_facts # 完整输出,不截断 }) except Exception as e: logger.debug(f"获取节点 {uuid} 失败: {e}") continue result.entity_insights = entity_insights result.total_entities = len(entity_insights) # Step 4: 构建所有关系链(不限制数量) relationship_chains = [] for edge_data in all_edges: # 处理所有边,不截断 if isinstance(edge_data, dict): source_uuid = edge_data.get('source_node_uuid', '') target_uuid = edge_data.get('target_node_uuid', '') relation_name = edge_data.get('name', '') source_name = node_map.get(source_uuid, NodeInfo('', '', [], '', {})).name or source_uuid[:8] target_name = node_map.get(target_uuid, NodeInfo('', '', [], '', {})).name or target_uuid[:8] chain = f"{source_name} --[{relation_name}]--> {target_name}" if chain not in relationship_chains: relationship_chains.append(chain) result.relationship_chains = relationship_chains result.total_relationships = len(relationship_chains) logger.info(f"InsightForge完成: {result.total_facts}条事实, {result.total_entities}个实体, {result.total_relationships}条关系") return result def _generate_sub_queries( self, query: str, simulation_requirement: str, report_context: str = "", max_queries: int = 5 ) -> List[str]: """ 使用LLM生成子问题 将复杂问题分解为多个可以独立检索的子问题 """ system_prompt = """你是一个专业的问题分析专家。你的任务是将一个复杂问题分解为多个可以在模拟世界中独立观察的子问题。 要求: 1. 每个子问题应该足够具体,可以在模拟世界中找到相关的Agent行为或事件 2. 子问题应该覆盖原问题的不同维度(如:谁、什么、为什么、怎么样、何时、何地) 3. 子问题应该与模拟场景相关 4. 
返回JSON格式:{"sub_queries": ["子问题1", "子问题2", ...]}""" user_prompt = f"""模拟需求背景: {simulation_requirement} {f"报告上下文:{report_context[:500]}" if report_context else ""} 请将以下问题分解为{max_queries}个子问题: {query} 返回JSON格式的子问题列表。""" try: response = self.llm.chat_json( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], temperature=0.3 ) sub_queries = response.get("sub_queries", []) # 确保是字符串列表 return [str(sq) for sq in sub_queries[:max_queries]] except Exception as e: logger.warning(f"生成子问题失败: {str(e)},使用默认子问题") # 降级:返回基于原问题的变体 return [ query, f"{query} 的主要参与者", f"{query} 的原因和影响", f"{query} 的发展过程" ][:max_queries] def panorama_search( self, graph_id: str, query: str, include_expired: bool = True, limit: int = 50 ) -> PanoramaResult: """ 【PanoramaSearch - 广度搜索】 获取全貌视图,包括所有相关内容和历史/过期信息: 1. 获取所有相关节点 2. 获取所有边(包括已过期/失效的) 3. 分类整理当前有效和历史信息 这个工具适用于需要了解事件全貌、追踪演变过程的场景。 Args: graph_id: 图谱ID query: 搜索查询(用于相关性排序) include_expired: 是否包含过期内容(默认True) limit: 返回结果数量限制 Returns: PanoramaResult: 广度搜索结果 """ logger.info(f"PanoramaSearch 广度搜索: {query[:50]}...") result = PanoramaResult(query=query) # 获取所有节点 all_nodes = self.get_all_nodes(graph_id) node_map = {n.uuid: n for n in all_nodes} result.all_nodes = all_nodes result.total_nodes = len(all_nodes) # 获取所有边(包含时间信息) all_edges = self.get_all_edges(graph_id, include_temporal=True) result.all_edges = all_edges result.total_edges = len(all_edges) # 分类事实 active_facts = [] historical_facts = [] for edge in all_edges: if not edge.fact: continue # 为事实添加实体名称 source_name = node_map.get(edge.source_node_uuid, NodeInfo('', '', [], '', {})).name or edge.source_node_uuid[:8] target_name = node_map.get(edge.target_node_uuid, NodeInfo('', '', [], '', {})).name or edge.target_node_uuid[:8] # 判断是否过期/失效 is_historical = edge.is_expired or edge.is_invalid if is_historical: # 历史/过期事实,添加时间标记 valid_at = edge.valid_at or "未知" invalid_at = edge.invalid_at or edge.expired_at or "未知" fact_with_time = f"[{valid_at} - {invalid_at}] 
def quick_search(
    self,
    graph_id: str,
    query: str,
    limit: int = 10
) -> SearchResult:
    """
    [QuickSearch - simple search]

    Lightweight retrieval: delegates directly to ``search_graph`` with the
    scope fixed to "edges" and returns its result unchanged.

    Args:
        graph_id: Graph ID
        query: Search query
        limit: Number of results to return

    Returns:
        SearchResult: whatever ``search_graph`` returned
    """
    logger.info(f"QuickSearch 简单搜索: {query[:50]}...")

    # Single delegated call; the scope is always "edges".
    search_kwargs = {
        "graph_id": graph_id,
        "query": query,
        "limit": limit,
        "scope": "edges",
    }
    found = self.search_graph(**search_kwargs)

    logger.info(f"QuickSearch完成: {found.total_count}条结果")
    return found
整合所有采访结果,生成采访报告 【重要】此功能需要模拟环境处于运行状态(OASIS环境未关闭) 【使用场景】 - 需要从不同角色视角了解事件看法 - 需要收集多方意见和观点 - 需要获取模拟Agent的真实回答(非LLM模拟) Args: simulation_id: 模拟ID(用于定位人设文件和调用采访API) interview_requirement: 采访需求描述(非结构化,如"了解学生对事件的看法") simulation_requirement: 模拟需求背景(可选) max_agents: 最多采访的Agent数量 custom_questions: 自定义采访问题(可选,若不提供则自动生成) Returns: InterviewResult: 采访结果 """ from .simulation_runner import SimulationRunner logger.info(f"InterviewAgents 深度采访(真实API): {interview_requirement[:50]}...") result = InterviewResult( interview_topic=interview_requirement, interview_questions=custom_questions or [] ) # Step 1: 读取人设文件 profiles = self._load_agent_profiles(simulation_id) if not profiles: logger.warning(f"未找到模拟 {simulation_id} 的人设文件") result.summary = "未找到可采访的Agent人设文件" return result result.total_agents = len(profiles) logger.info(f"加载到 {len(profiles)} 个Agent人设") # Step 2: 使用LLM选择要采访的Agent(返回agent_id列表) selected_agents, selected_indices, selection_reasoning = self._select_agents_for_interview( profiles=profiles, interview_requirement=interview_requirement, simulation_requirement=simulation_requirement, max_agents=max_agents ) result.selected_agents = selected_agents result.selection_reasoning = selection_reasoning logger.info(f"选择了 {len(selected_agents)} 个Agent进行采访: {selected_indices}") # Step 3: 生成采访问题(如果没有提供) if not result.interview_questions: result.interview_questions = self._generate_interview_questions( interview_requirement=interview_requirement, simulation_requirement=simulation_requirement, selected_agents=selected_agents ) logger.info(f"生成了 {len(result.interview_questions)} 个采访问题") # 将问题合并为一个采访prompt combined_prompt = "\n".join([f"{i+1}. {q}" for i, q in enumerate(result.interview_questions)]) # 添加优化前缀,约束Agent回复格式 INTERVIEW_PROMPT_PREFIX = ( "你正在接受一次采访。请结合你的人设、所有的过往记忆与行动," "以纯文本方式直接回答以下问题。\n" "回复要求:\n" "1. 直接用自然语言回答,不要调用任何工具\n" "2. 不要返回JSON格式或工具调用格式\n" "3. 不要使用Markdown标题(如#、##、###)\n" "4. 按问题编号逐一回答,每个回答以「问题X:」开头(X为问题编号)\n" "5. 每个问题的回答之间用空行分隔\n" "6. 
回答要有实质内容,每个问题至少回答2-3句话\n\n" ) optimized_prompt = f"{INTERVIEW_PROMPT_PREFIX}{combined_prompt}" # Step 4: 调用真实的采访API(不指定platform,默认双平台同时采访) try: # 构建批量采访列表(不指定platform,双平台采访) interviews_request = [] for agent_idx in selected_indices: interviews_request.append({ "agent_id": agent_idx, "prompt": optimized_prompt # 使用优化后的prompt # 不指定platform,API会在twitter和reddit两个平台都采访 }) logger.info(f"调用批量采访API(双平台): {len(interviews_request)} 个Agent") # 调用 SimulationRunner 的批量采访方法(不传platform,双平台采访) api_result = SimulationRunner.interview_agents_batch( simulation_id=simulation_id, interviews=interviews_request, platform=None, # 不指定platform,双平台采访 timeout=180.0 # 双平台需要更长超时 ) logger.info(f"采访API返回: {api_result.get('interviews_count', 0)} 个结果, success={api_result.get('success')}") # 检查API调用是否成功 if not api_result.get("success", False): error_msg = api_result.get("error", "未知错误") logger.warning(f"采访API返回失败: {error_msg}") result.summary = f"采访API调用失败:{error_msg}。请检查OASIS模拟环境状态。" return result # Step 5: 解析API返回结果,构建AgentInterview对象 # 双平台模式返回格式: {"twitter_0": {...}, "reddit_0": {...}, "twitter_1": {...}, ...} api_data = api_result.get("result", {}) results_dict = api_data.get("results", {}) if isinstance(api_data, dict) else {} for i, agent_idx in enumerate(selected_indices): agent = selected_agents[i] agent_name = agent.get("realname", agent.get("username", f"Agent_{agent_idx}")) agent_role = agent.get("profession", "未知") agent_bio = agent.get("bio", "") # 获取该Agent在两个平台的采访结果 twitter_result = results_dict.get(f"twitter_{agent_idx}", {}) reddit_result = results_dict.get(f"reddit_{agent_idx}", {}) twitter_response = twitter_result.get("response", "") reddit_response = reddit_result.get("response", "") # 清理可能的工具调用 JSON 包裹 twitter_response = self._clean_tool_call_response(twitter_response) reddit_response = self._clean_tool_call_response(reddit_response) # 始终输出双平台标记 twitter_text = twitter_response if twitter_response else "(该平台未获得回复)" reddit_text = reddit_response if reddit_response else 
"(该平台未获得回复)" response_text = f"【Twitter平台回答】\n{twitter_text}\n\n【Reddit平台回答】\n{reddit_text}" # 提取关键引言(从两个平台的回答中) import re combined_responses = f"{twitter_response} {reddit_response}" # 清理响应文本:去掉标记、编号、Markdown 等干扰 clean_text = re.sub(r'#{1,6}\s+', '', combined_responses) clean_text = re.sub(r'\{[^}]*tool_name[^}]*\}', '', clean_text) clean_text = re.sub(r'[*_`|>~\-]{2,}', '', clean_text) clean_text = re.sub(r'问题\d+[::]\s*', '', clean_text) clean_text = re.sub(r'【[^】]+】', '', clean_text) # 策略1(主): 提取完整的有实质内容的句子 sentences = re.split(r'[。!?]', clean_text) meaningful = [ s.strip() for s in sentences if 20 <= len(s.strip()) <= 150 and not re.match(r'^[\s\W,,;;::、]+', s.strip()) and not s.strip().startswith(('{', '问题')) ] meaningful.sort(key=len, reverse=True) key_quotes = [s + "。" for s in meaningful[:3]] # 策略2(补充): 正确配对的中文引号「」内长文本 if not key_quotes: paired = re.findall(r'\u201c([^\u201c\u201d]{15,100})\u201d', clean_text) paired += re.findall(r'\u300c([^\u300c\u300d]{15,100})\u300d', clean_text) key_quotes = [q for q in paired if not re.match(r'^[,,;;::、]', q)][:3] interview = AgentInterview( agent_name=agent_name, agent_role=agent_role, agent_bio=agent_bio[:1000], # 扩大bio长度限制 question=combined_prompt, response=response_text, key_quotes=key_quotes[:5] ) result.interviews.append(interview) result.interviewed_count = len(result.interviews) except ValueError as e: # 模拟环境未运行 logger.warning(f"采访API调用失败(环境未运行?): {e}") result.summary = f"采访失败:{str(e)}。模拟环境可能已关闭,请确保OASIS环境正在运行。" return result except Exception as e: logger.error(f"采访API调用异常: {e}") import traceback logger.error(traceback.format_exc()) result.summary = f"采访过程发生错误:{str(e)}" return result # Step 6: 生成采访摘要 if result.interviews: result.summary = self._generate_interview_summary( interviews=result.interviews, interview_requirement=interview_requirement ) logger.info(f"InterviewAgents完成: 采访了 {result.interviewed_count} 个Agent(双平台)") return result @staticmethod def _clean_tool_call_response(response: str) -> str: """清理 
Agent 回复中的 JSON 工具调用包裹,提取实际内容""" if not response or not response.strip().startswith('{'): return response text = response.strip() if 'tool_name' not in text[:80]: return response import re as _re try: data = json.loads(text) if isinstance(data, dict) and 'arguments' in data: for key in ('content', 'text', 'body', 'message', 'reply'): if key in data['arguments']: return str(data['arguments'][key]) except (json.JSONDecodeError, KeyError, TypeError): match = _re.search(r'"content"\s*:\s*"((?:[^"\\]|\\.)*)"', text) if match: return match.group(1).replace('\\n', '\n').replace('\\"', '"') return response def _load_agent_profiles(self, simulation_id: str) -> List[Dict[str, Any]]: """加载模拟的Agent人设文件""" import os import csv # 构建人设文件路径 sim_dir = os.path.join( os.path.dirname(__file__), f'../../uploads/simulations/{simulation_id}' ) profiles = [] # 优先尝试读取Reddit JSON格式 reddit_profile_path = os.path.join(sim_dir, "reddit_profiles.json") if os.path.exists(reddit_profile_path): try: with open(reddit_profile_path, 'r', encoding='utf-8') as f: profiles = json.load(f) logger.info(f"从 reddit_profiles.json 加载了 {len(profiles)} 个人设") return profiles except Exception as e: logger.warning(f"读取 reddit_profiles.json 失败: {e}") # 尝试读取Twitter CSV格式 twitter_profile_path = os.path.join(sim_dir, "twitter_profiles.csv") if os.path.exists(twitter_profile_path): try: with open(twitter_profile_path, 'r', encoding='utf-8') as f: reader = csv.DictReader(f) for row in reader: # CSV格式转换为统一格式 profiles.append({ "realname": row.get("name", ""), "username": row.get("username", ""), "bio": row.get("description", ""), "persona": row.get("user_char", ""), "profession": "未知" }) logger.info(f"从 twitter_profiles.csv 加载了 {len(profiles)} 个人设") return profiles except Exception as e: logger.warning(f"读取 twitter_profiles.csv 失败: {e}") return profiles def _select_agents_for_interview( self, profiles: List[Dict[str, Any]], interview_requirement: str, simulation_requirement: str, max_agents: int ) -> tuple: """ 
使用LLM选择要采访的Agent Returns: tuple: (selected_agents, selected_indices, reasoning) - selected_agents: 选中Agent的完整信息列表 - selected_indices: 选中Agent的索引列表(用于API调用) - reasoning: 选择理由 """ # 构建Agent摘要列表 agent_summaries = [] for i, profile in enumerate(profiles): summary = { "index": i, "name": profile.get("realname", profile.get("username", f"Agent_{i}")), "profession": profile.get("profession", "未知"), "bio": profile.get("bio", "")[:200], "interested_topics": profile.get("interested_topics", []) } agent_summaries.append(summary) system_prompt = """你是一个专业的采访策划专家。你的任务是根据采访需求,从模拟Agent列表中选择最适合采访的对象。 选择标准: 1. Agent的身份/职业与采访主题相关 2. Agent可能持有独特或有价值的观点 3. 选择多样化的视角(如:支持方、反对方、中立方、专业人士等) 4. 优先选择与事件直接相关的角色 返回JSON格式: { "selected_indices": [选中Agent的索引列表], "reasoning": "选择理由说明" }""" user_prompt = f"""采访需求: {interview_requirement} 模拟背景: {simulation_requirement if simulation_requirement else "未提供"} 可选择的Agent列表(共{len(agent_summaries)}个): {json.dumps(agent_summaries, ensure_ascii=False, indent=2)} 请选择最多{max_agents}个最适合采访的Agent,并说明选择理由。""" try: response = self.llm.chat_json( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], temperature=0.3 ) selected_indices = response.get("selected_indices", [])[:max_agents] reasoning = response.get("reasoning", "基于相关性自动选择") # 获取选中的Agent完整信息 selected_agents = [] valid_indices = [] for idx in selected_indices: if 0 <= idx < len(profiles): selected_agents.append(profiles[idx]) valid_indices.append(idx) return selected_agents, valid_indices, reasoning except Exception as e: logger.warning(f"LLM选择Agent失败,使用默认选择: {e}") # 降级:选择前N个 selected = profiles[:max_agents] indices = list(range(min(max_agents, len(profiles)))) return selected, indices, "使用默认选择策略" def _generate_interview_questions( self, interview_requirement: str, simulation_requirement: str, selected_agents: List[Dict[str, Any]] ) -> List[str]: """使用LLM生成采访问题""" agent_roles = [a.get("profession", "未知") for a in selected_agents] system_prompt = 
"""你是一个专业的记者/采访者。根据采访需求,生成3-5个深度采访问题。 问题要求: 1. 开放性问题,鼓励详细回答 2. 针对不同角色可能有不同答案 3. 涵盖事实、观点、感受等多个维度 4. 语言自然,像真实采访一样 5. 每个问题控制在50字以内,简洁明了 6. 直接提问,不要包含背景说明或前缀 返回JSON格式:{"questions": ["问题1", "问题2", ...]}""" user_prompt = f"""采访需求:{interview_requirement} 模拟背景:{simulation_requirement if simulation_requirement else "未提供"} 采访对象角色:{', '.join(agent_roles)} 请生成3-5个采访问题。""" try: response = self.llm.chat_json( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], temperature=0.5 ) return response.get("questions", [f"关于{interview_requirement},您有什么看法?"]) except Exception as e: logger.warning(f"生成采访问题失败: {e}") return [ f"关于{interview_requirement},您的观点是什么?", "这件事对您或您所代表的群体有什么影响?", "您认为应该如何解决或改进这个问题?" ] def _generate_interview_summary( self, interviews: List[AgentInterview], interview_requirement: str ) -> str: """生成采访摘要""" if not interviews: return "未完成任何采访" # 收集所有采访内容 interview_texts = [] for interview in interviews: interview_texts.append(f"【{interview.agent_name}({interview.agent_role})】\n{interview.response[:500]}") system_prompt = """你是一个专业的新闻编辑。请根据多位受访者的回答,生成一份采访摘要。 摘要要求: 1. 提炼各方主要观点 2. 指出观点的共识和分歧 3. 突出有价值的引言 4. 客观中立,不偏袒任何一方 5. 
控制在1000字内 格式约束(必须遵守): - 使用纯文本段落,用空行分隔不同部分 - 不要使用Markdown标题(如#、##、###) - 不要使用分割线(如---、***) - 引用受访者原话时使用中文引号「」 - 可以使用**加粗**标记关键词,但不要使用其他Markdown语法""" user_prompt = f"""采访主题:{interview_requirement} 采访内容: {"".join(interview_texts)} 请生成采访摘要。""" try: summary = self.llm.chat( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], temperature=0.3, max_tokens=800 ) return summary except Exception as e: logger.warning(f"生成采访摘要失败: {e}") # 降级:简单拼接 return f"共采访了{len(interviews)}位受访者,包括:" + "、".join([i.agent_name for i in interviews]) ================================================ FILE: backend/app/utils/__init__.py ================================================ """ 工具模块 """ from .file_parser import FileParser from .llm_client import LLMClient __all__ = ['FileParser', 'LLMClient'] ================================================ FILE: backend/app/utils/file_parser.py ================================================ """ 文件解析工具 支持PDF、Markdown、TXT文件的文本提取 """ import os from pathlib import Path from typing import List, Optional def _read_text_with_fallback(file_path: str) -> str: """ 读取文本文件,UTF-8失败时自动探测编码。 采用多级回退策略: 1. 首先尝试 UTF-8 解码 2. 使用 charset_normalizer 检测编码 3. 回退到 chardet 检测编码 4. 
最终使用 UTF-8 + errors='replace' 兜底 Args: file_path: 文件路径 Returns: 解码后的文本内容 """ data = Path(file_path).read_bytes() # 首先尝试 UTF-8 try: return data.decode('utf-8') except UnicodeDecodeError: pass # 尝试使用 charset_normalizer 检测编码 encoding = None try: from charset_normalizer import from_bytes best = from_bytes(data).best() if best and best.encoding: encoding = best.encoding except Exception: pass # 回退到 chardet if not encoding: try: import chardet result = chardet.detect(data) encoding = result.get('encoding') if result else None except Exception: pass # 最终兜底:使用 UTF-8 + replace if not encoding: encoding = 'utf-8' return data.decode(encoding, errors='replace') class FileParser: """文件解析器""" SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'} @classmethod def extract_text(cls, file_path: str) -> str: """ 从文件中提取文本 Args: file_path: 文件路径 Returns: 提取的文本内容 """ path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"文件不存在: {file_path}") suffix = path.suffix.lower() if suffix not in cls.SUPPORTED_EXTENSIONS: raise ValueError(f"不支持的文件格式: {suffix}") if suffix == '.pdf': return cls._extract_from_pdf(file_path) elif suffix in {'.md', '.markdown'}: return cls._extract_from_md(file_path) elif suffix == '.txt': return cls._extract_from_txt(file_path) raise ValueError(f"无法处理的文件格式: {suffix}") @staticmethod def _extract_from_pdf(file_path: str) -> str: """从PDF提取文本""" try: import fitz # PyMuPDF except ImportError: raise ImportError("需要安装PyMuPDF: pip install PyMuPDF") text_parts = [] with fitz.open(file_path) as doc: for page in doc: text = page.get_text() if text.strip(): text_parts.append(text) return "\n\n".join(text_parts) @staticmethod def _extract_from_md(file_path: str) -> str: """从Markdown提取文本,支持自动编码检测""" return _read_text_with_fallback(file_path) @staticmethod def _extract_from_txt(file_path: str) -> str: """从TXT提取文本,支持自动编码检测""" return _read_text_with_fallback(file_path) @classmethod def extract_from_multiple(cls, file_paths: List[str]) -> str: """ 从多个文件提取文本并合并 
def split_text_into_chunks(
    text: str,
    chunk_size: int = 500,
    overlap: int = 50
) -> List[str]:
    """
    Split text into overlapping chunks, preferring sentence boundaries.

    Each chunk spans at most ``chunk_size`` characters. When a sentence
    terminator is found in the window past 30% of ``chunk_size``, the chunk
    is cut right after it. The next chunk starts ``overlap`` characters
    before the end of the previous one, but the start position always moves
    forward. Without that guarantee the loop never terminated when
    ``overlap >= chunk_size`` or when a sentence-boundary cut was shorter
    than ``overlap``.

    Args:
        text: Source text
        chunk_size: Maximum number of characters per chunk (must be > 0)
        overlap: Number of characters shared between consecutive chunks

    Returns:
        List of non-empty, whitespace-stripped chunks. Text that already
        fits in one chunk is returned as-is (unstripped) in a one-item list,
        or as an empty list when it is blank.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    if len(text) <= chunk_size:
        return [text] if text.strip() else []

    # Tried in order; the first separator found far enough into the window wins.
    separators = ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']
    min_cut = chunk_size * 0.3

    chunks = []
    text_len = len(text)
    start = 0

    while start < text_len:
        end = start + chunk_size

        # Try to cut at a sentence boundary unless this is the final window.
        if end < text_len:
            window = text[start:end]
            for sep in separators:
                last_sep = window.rfind(sep)
                if last_sep != -1 and last_sep > min_cut:
                    end = start + last_sep + len(sep)
                    break

        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)

        if end >= text_len:
            break

        # Step back by the overlap, but never to or before the current start:
        # in that case drop the overlap for this step so the loop progresses.
        next_start = end - overlap
        start = next_start if next_start > start else end

    return chunks
def chat_json(
    self,
    messages: List[Dict[str, str]],
    temperature: float = 0.3,
    max_tokens: int = 4096
) -> Dict[str, Any]:
    """
    Send a chat request in JSON mode and return the parsed reply.

    The reply is obtained from ``self.chat`` with
    ``response_format={"type": "json_object"}``. A leading or trailing
    Markdown code fence is removed before parsing.

    Args:
        messages: Message list
        temperature: Sampling temperature
        max_tokens: Maximum number of tokens

    Returns:
        The parsed JSON value

    Raises:
        ValueError: if the cleaned reply is not valid JSON.
    """
    raw_reply = self.chat(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format={"type": "json_object"}
    )

    # Strip a Markdown code fence: opening fence first (case-insensitive
    # "json" tag), then the closing fence.
    fence_rules = (
        (r'^```(?:json)?\s*\n?', re.IGNORECASE),
        (r'\n?```\s*$', 0),
    )
    payload = raw_reply.strip()
    for pattern, flags in fence_rules:
        payload = re.sub(pattern, '', payload, flags=flags)
    payload = payload.strip()

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        raise ValueError(f"LLM返回的JSON格式无效: {payload}")
    return parsed
def get_logger(name: str = 'mirofish') -> logging.Logger:
    """
    Return the logger registered under ``name``, configuring it on first use.

    A logger that already has handlers is returned as-is; otherwise the
    call is delegated to ``setup_logger(name)``.

    Args:
        name: Logger name

    Returns:
        Logger instance
    """
    existing = logging.getLogger(name)
    return existing if existing.handlers else setup_logger(name)
def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[Exception, int], None]] = None
):
    """
    Decorator that retries a function with exponential backoff.

    The wrapped function is attempted up to ``max_retries + 1`` times. After
    each failed attempt except the last, the wrapper logs a warning, invokes
    ``on_retry`` (if given), sleeps, and multiplies the base delay by
    ``backoff_factor``. The final failure is logged and re-raised.

    Args:
        max_retries: Maximum number of retries
        initial_delay: Initial delay in seconds
        max_delay: Upper bound applied to the delay before jitter, in seconds
        backoff_factor: Multiplier applied to the delay after each retry
        jitter: When true, scale each delay by a random factor in [0.5, 1.5)
        exceptions: Exception types that trigger a retry
        on_retry: Callback invoked as ``on_retry(exception, retry_count)``

    Usage:
        @retry_with_backoff(max_retries=3)
        def call_llm_api():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            failure = None
            pending_delay = initial_delay
            total_attempts = max_retries + 1

            # attempt_no is 1-based, matching the number shown in the logs.
            for attempt_no in range(1, total_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as exc:
                    failure = exc

                    if attempt_no == total_attempts:
                        logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(exc)}")
                        raise

                    # Cap first, then apply jitter, so a jittered delay may
                    # exceed max_delay by up to 50%.
                    wait_for = min(pending_delay, max_delay)
                    if jitter:
                        wait_for = wait_for * (0.5 + random.random())

                    logger.warning(
                        f"函数 {func.__name__} 第 {attempt_no} 次尝试失败: {str(exc)}, "
                        f"{wait_for:.1f}秒后重试..."
                    )

                    if on_retry:
                        on_retry(exc, attempt_no)

                    time.sleep(wait_for)
                    pending_delay *= backoff_factor

            raise failure

        return wrapper
    return decorator
class RetryableAPIClient:
    """
    Helper that runs callables with retry and exponential backoff.

    The retry settings are stored on the instance and reused by every call.
    """

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 30.0,
        backoff_factor: float = 2.0
    ):
        """
        Store the retry configuration.

        Args:
            max_retries: Maximum number of retries per call
            initial_delay: Initial delay in seconds
            max_delay: Upper bound applied to the delay before jitter, in seconds
            backoff_factor: Multiplier applied to the delay after each retry
        """
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor

    def call_with_retry(
        self,
        func: Callable,
        *args,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        **kwargs
    ) -> Any:
        """
        Call ``func`` and retry it when it raises one of ``exceptions``.

        Args:
            func: Function to call
            *args: Positional arguments for ``func``
            exceptions: Exception types that trigger a retry
            **kwargs: Keyword arguments for ``func``

        Returns:
            The return value of ``func``

        Raises:
            The exception of the last attempt once all retries are used up;
            exceptions outside ``exceptions`` propagate immediately.
        """
        failure = None
        pending_delay = self.initial_delay
        total_attempts = self.max_retries + 1

        # attempt_no is 1-based, matching the number shown in the logs.
        for attempt_no in range(1, total_attempts + 1):
            try:
                return func(*args, **kwargs)
            except exceptions as exc:
                failure = exc

                if attempt_no == total_attempts:
                    logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(exc)}")
                    raise

                # Jitter is always applied here: cap the delay, then scale
                # it by a random factor in [0.5, 1.5).
                wait_for = min(pending_delay, self.max_delay) * (0.5 + random.random())

                logger.warning(
                    f"API调用第 {attempt_no} 次尝试失败: {str(exc)}, "
                    f"{wait_for:.1f}秒后重试..."
                )

                time.sleep(wait_for)
                pending_delay *= self.backoff_factor

        raise failure

    def call_batch_with_retry(
        self,
        items: list,
        process_func: Callable,
        exceptions: Tuple[Type[Exception], ...] = (Exception,),
        continue_on_failure: bool = True
    ) -> Tuple[list, list]:
        """
        Process items one by one, retrying each failed item individually.

        Args:
            items: Items to process
            process_func: Function that receives a single item
            exceptions: Exception types that trigger a retry
            continue_on_failure: Whether to keep going after an item fails

        Returns:
            ``(results, failures)``: results of the successful items in
            order, and one ``{"index", "item", "error"}`` dict per failed item.

        Raises:
            The item's exception when ``continue_on_failure`` is false.
        """
        succeeded = []
        failed = []

        for position, entry in enumerate(items):
            try:
                outcome = self.call_with_retry(
                    process_func,
                    entry,
                    exceptions=exceptions
                )
            except Exception as exc:
                logger.error(f"处理第 {position + 1} 项失败: {str(exc)}")
                failed.append({
                    "index": position,
                    "item": entry,
                    "error": str(exc)
                })
                if not continue_on_failure:
                    raise
            else:
                succeeded.append(outcome)

        return succeeded, failed
def fetch_all_nodes(
    client: Zep,
    graph_id: str,
    page_size: int = _DEFAULT_PAGE_SIZE,
    max_items: int = _MAX_NODES,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
    """Fetch graph nodes page by page, returning at most ``max_items`` of them.

    Every page request goes through ``_fetch_page_with_retry``. Pagination ends
    on an empty page, on a page shorter than ``page_size``, when the item cap
    is reached, or when the last node of a page exposes no uuid to use as the
    next cursor.
    """
    collected: list[Any] = []
    next_cursor: str | None = None
    page_index = 0

    while True:
        page_index += 1

        request_kwargs: dict[str, Any] = {"limit": page_size}
        if next_cursor is not None:
            request_kwargs["uuid_cursor"] = next_cursor

        page = _fetch_page_with_retry(
            client.graph.node.get_by_graph_id,
            graph_id,
            max_retries=max_retries,
            retry_delay=retry_delay,
            page_description=f"fetch nodes page {page_index} (graph={graph_id})",
            **request_kwargs,
        )
        if not page:
            return collected

        collected.extend(page)

        if len(collected) >= max_items:
            logger.warning(f"Node count reached limit ({max_items}), stopping pagination for graph {graph_id}")
            return collected[:max_items]

        if len(page) < page_size:
            # A short page means there is nothing left to fetch.
            return collected

        last_node = page[-1]
        next_cursor = getattr(last_node, "uuid_", None) or getattr(last_node, "uuid", None)
        if next_cursor is None:
            logger.warning(f"Node missing uuid field, stopping pagination at {len(collected)} nodes")
            return collected
(graph={graph_id})", **kwargs, ) if not batch: break all_edges.extend(batch) if len(batch) < page_size: break cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None) if cursor is None: logger.warning(f"Edge missing uuid field, stopping pagination at {len(all_edges)} edges") break return all_edges ================================================ FILE: backend/pyproject.toml ================================================ [project] name = "mirofish-backend" version = "0.1.0" description = "MiroFish - 简洁通用的群体智能引擎,预测万物" requires-python = ">=3.11" license = { text = "AGPL-3.0" } authors = [ { name = "MiroFish Team" } ] dependencies = [ # 核心框架 "flask>=3.0.0", "flask-cors>=6.0.0", # LLM 相关 "openai>=1.0.0", # Zep Cloud "zep-cloud==3.13.0", # OASIS 社交媒体模拟 "camel-oasis==0.2.5", "camel-ai==0.2.78", # 文件处理 "PyMuPDF>=1.24.0", # 编码检测(支持非UTF-8编码的文本文件) "charset-normalizer>=3.0.0", "chardet>=5.0.0", # 工具库 "python-dotenv>=1.0.0", "pydantic>=2.0.0", ] [project.optional-dependencies] dev = [ "pytest>=8.0.0", "pytest-asyncio>=0.23.0", "pipreqs>=0.5.0", ] [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [dependency-groups] dev = [ "pytest>=8.0.0", "pytest-asyncio>=0.23.0", ] [tool.hatch.build.targets.wheel] packages = ["app"] ================================================ FILE: backend/requirements.txt ================================================ # =========================================== # MiroFish Backend Dependencies # =========================================== # Python 3.11+ required # Install: pip install -r requirements.txt # =========================================== # ============= 核心框架 ============= flask>=3.0.0 flask-cors>=6.0.0 # ============= LLM 相关 ============= # OpenAI SDK(统一使用 OpenAI 格式调用 LLM) openai>=1.0.0 # ============= Zep Cloud ============= zep-cloud==3.13.0 # ============= OASIS 社交媒体模拟 ============= # OASIS 社交模拟框架 camel-oasis==0.2.5 camel-ai==0.2.78 # ============= 文件处理 ============= PyMuPDF>=1.24.0 # 
def main():
    """Entry point: validate the configuration, build the Flask app and serve it."""
    config_errors = Config.validate()
    if config_errors:
        # List every configuration problem, then exit with status 1.
        print("配置错误:")
        for problem in config_errors:
            print(f" - {problem}")
        print("\n请检查 .env 文件中的配置")
        sys.exit(1)

    app = create_app()

    # FLASK_HOST / FLASK_PORT override the defaults; the debug flag comes from Config.
    app.run(
        host=os.environ.get('FLASK_HOST', '0.0.0.0'),
        port=int(os.environ.get('FLASK_PORT', 5001)),
        debug=Config.DEBUG,
        threaded=True,
    )
os.path.join(self.log_dir, "actions.jsonl") self._ensure_dir() def _ensure_dir(self): """确保目录存在""" os.makedirs(self.log_dir, exist_ok=True) def log_action( self, round_num: int, agent_id: int, agent_name: str, action_type: str, action_args: Optional[Dict[str, Any]] = None, result: Optional[str] = None, success: bool = True ): """记录一个动作""" entry = { "round": round_num, "timestamp": datetime.now().isoformat(), "agent_id": agent_id, "agent_name": agent_name, "action_type": action_type, "action_args": action_args or {}, "result": result, "success": success, } with open(self.log_path, 'a', encoding='utf-8') as f: f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_round_start(self, round_num: int, simulated_hour: int): """记录轮次开始""" entry = { "round": round_num, "timestamp": datetime.now().isoformat(), "event_type": "round_start", "simulated_hour": simulated_hour, } with open(self.log_path, 'a', encoding='utf-8') as f: f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_round_end(self, round_num: int, actions_count: int): """记录轮次结束""" entry = { "round": round_num, "timestamp": datetime.now().isoformat(), "event_type": "round_end", "actions_count": actions_count, } with open(self.log_path, 'a', encoding='utf-8') as f: f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_simulation_start(self, config: Dict[str, Any]): """记录模拟开始""" entry = { "timestamp": datetime.now().isoformat(), "event_type": "simulation_start", "platform": self.platform, "total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2, "agents_count": len(config.get("agent_configs", [])), } with open(self.log_path, 'a', encoding='utf-8') as f: f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_simulation_end(self, total_rounds: int, total_actions: int): """记录模拟结束""" entry = { "timestamp": datetime.now().isoformat(), "event_type": "simulation_end", "platform": self.platform, "total_rounds": total_rounds, "total_actions": total_actions, } with 
class SimulationLogManager:
    """
    Central log manager for one simulation directory.

    Owns the main ``simulation.log`` logger (file + console) and lazily creates
    one ``PlatformActionLogger`` per platform.
    """

    # Accepted level names mapped to the logging.Logger method that emits them.
    # Anything else falls back to "info".
    _LEVEL_METHODS = {
        "debug": "debug",
        "info": "info",
        "warning": "warning",
        "warn": "warning",
        "error": "error",
        "exception": "exception",
        "critical": "critical",
        "fatal": "critical",
    }

    def __init__(self, simulation_dir: str):
        """
        Initialize the manager and set up the main logger.

        Args:
            simulation_dir: Path of the simulation directory.
        """
        self.simulation_dir = simulation_dir
        self.twitter_logger: Optional["PlatformActionLogger"] = None
        self.reddit_logger: Optional["PlatformActionLogger"] = None
        self._main_logger: Optional[logging.Logger] = None

        self._setup_main_logger()

    def _setup_main_logger(self):
        """Configure the main simulation logger with a file and a console handler."""
        # FileHandler does not create missing parent directories.
        if self.simulation_dir:
            os.makedirs(self.simulation_dir, exist_ok=True)
        log_path = os.path.join(self.simulation_dir, "simulation.log")

        main_logger = logging.getLogger(f"simulation.{os.path.basename(self.simulation_dir)}")
        main_logger.setLevel(logging.INFO)

        # logging.getLogger returns the same object for the same name, so a
        # previous manager may have left handlers behind. Close them instead of
        # only dropping the references, otherwise the old log file stays open.
        for stale_handler in list(main_logger.handlers):
            main_logger.removeHandler(stale_handler)
            stale_handler.close()

        # File handler: the log file is truncated on every setup (mode='w').
        file_handler = logging.FileHandler(log_path, encoding='utf-8', mode='w')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        main_logger.addHandler(file_handler)

        # Console handler with a shorter format.
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter(
            '[%(asctime)s] %(message)s',
            datefmt='%H:%M:%S'
        ))
        main_logger.addHandler(console_handler)

        # Keep records out of the root logger's handlers.
        main_logger.propagate = False
        self._main_logger = main_logger

    def get_twitter_logger(self) -> "PlatformActionLogger":
        """Return the Twitter action logger, creating it on first use."""
        if self.twitter_logger is None:
            self.twitter_logger = PlatformActionLogger("twitter", self.simulation_dir)
        return self.twitter_logger

    def get_reddit_logger(self) -> "PlatformActionLogger":
        """Return the Reddit action logger, creating it on first use."""
        if self.reddit_logger is None:
            self.reddit_logger = PlatformActionLogger("reddit", self.simulation_dir)
        return self.reddit_logger

    def log(self, message: str, level: str = "info"):
        """
        Write a message to the main log.

        Args:
            message: Text to log.
            level: Level name (case-insensitive). Unknown names are logged at
                info level instead of being looked up as logger attributes.
        """
        if not self._main_logger:
            return
        method_name = self._LEVEL_METHODS.get(level.lower(), "info")
        getattr(self._main_logger, method_name)(message)

    def info(self, message: str):
        """Log ``message`` at info level."""
        self.log(message, "info")

    def warning(self, message: str):
        """Log ``message`` at warning level."""
        self.log(message, "warning")

    def error(self, message: str):
        """Log ``message`` at error level."""
        self.log(message, "error")

    def debug(self, message: str):
        """Log ``message`` at debug level (dropped while the logger level is INFO)."""
        self.log(message, "debug")
def get_logger(log_path: Optional[str] = None) -> ActionLogger:
    """Return the process-wide ActionLogger (legacy interface).

    A truthy ``log_path`` always replaces the global instance with a new logger
    for that path. Without one, the existing instance is reused, or a logger
    writing to "actions.jsonl" is created on first use.
    """
    global _global_logger

    if log_path:
        _global_logger = ActionLogger(log_path)
    elif _global_logger is None:
        _global_logger = ActionLogger("actions.jsonl")

    return _global_logger
class MaxTokensWarningFilter(logging.Filter):
    """Drop camel-ai's max_tokens warnings (max_tokens is left unset on purpose so the model decides)."""

    def filter(self, record):
        # Suppress only records whose rendered message mentions both markers.
        rendered = record.getMessage()
        is_max_tokens_warning = "max_tokens" in rendered and "Invalid or missing" in rendered
        return not is_max_tokens_warning
def disable_oasis_logging():
    """
    Silence the verbose loggers of the OASIS library.

    Each listed logger is raised to CRITICAL, stripped of its handlers and
    detached from propagation; the simulation relies on its own action logger
    instead.
    """
    noisy_logger_names = (
        "social.agent",
        "social.twitter",
        "social.rec",
        "oasis.env",
        "table",
    )

    for name in noisy_logger_names:
        noisy = logging.getLogger(name)
        noisy.setLevel(logging.CRITICAL)  # only critical records pass
        noisy.handlers.clear()
        noisy.propagate = False
def poll_command(self) -> Optional[Dict[str, Any]]:
    """
    Return the oldest readable pending command, or None when there is none.

    Only ``*.json`` files inside ``self.commands_dir`` are considered, ordered
    by modification time. Files that cannot be stat'ed, opened or parsed are
    skipped, so a single bad or vanished file cannot break the polling loop.

    Returns:
        The parsed JSON content of the oldest readable command file, or None.
    """
    try:
        filenames = os.listdir(self.commands_dir)
    except OSError:
        # Directory missing or unreadable: nothing to process.
        return None

    candidates = []
    for filename in filenames:
        if not filename.endswith('.json'):
            continue
        filepath = os.path.join(self.commands_dir, filename)
        try:
            candidates.append((os.path.getmtime(filepath), filepath))
        except OSError:
            # The file disappeared between listing and stat (or is a dangling link).
            continue

    # Oldest first; the sort is stable, so equal mtimes keep listing order.
    candidates.sort(key=lambda candidate: candidate[0])

    for _, filepath in candidates:
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (ValueError, OSError):
            # ValueError covers JSONDecodeError and undecodable bytes.
            continue

    return None
os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) except OSError: pass def _get_env_and_graph(self, platform: str): """ 获取指定平台的环境和agent_graph Args: platform: 平台名称 ("twitter" 或 "reddit") Returns: (env, agent_graph, platform_name) 或 (None, None, None) """ if platform == "twitter" and self.twitter_env: return self.twitter_env, self.twitter_agent_graph, "twitter" elif platform == "reddit" and self.reddit_env: return self.reddit_env, self.reddit_agent_graph, "reddit" else: return None, None, None async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]: """ 在单个平台上执行Interview Returns: 包含结果的字典,或包含error的字典 """ env, agent_graph, actual_platform = self._get_env_and_graph(platform) if not env or not agent_graph: return {"platform": platform, "error": f"{platform}平台不可用"} try: agent = agent_graph.get_agent(agent_id) interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) actions = {agent: interview_action} await env.step(actions) result = self._get_interview_result(agent_id, actual_platform) result["platform"] = actual_platform return result except Exception as e: return {"platform": platform, "error": str(e)} async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool: """ 处理单个Agent采访命令 Args: command_id: 命令ID agent_id: Agent ID prompt: 采访问题 platform: 指定平台(可选) - "twitter": 只采访Twitter平台 - "reddit": 只采访Reddit平台 - None/不指定: 同时采访两个平台,返回整合结果 Returns: True 表示成功,False 表示失败 """ # 如果指定了平台,只采访该平台 if platform in ("twitter", "reddit"): result = await self._interview_single_platform(agent_id, prompt, platform) if "error" in result: self.send_response(command_id, "failed", error=result["error"]) print(f" Interview失败: agent_id={agent_id}, platform={platform}, error={result['error']}") return False else: self.send_response(command_id, "completed", result=result) print(f" Interview完成: agent_id={agent_id}, 
platform={platform}") return True # 未指定平台:同时采访两个平台 if not self.twitter_env and not self.reddit_env: self.send_response(command_id, "failed", error="没有可用的模拟环境") return False results = { "agent_id": agent_id, "prompt": prompt, "platforms": {} } success_count = 0 # 并行采访两个平台 tasks = [] platforms_to_interview = [] if self.twitter_env: tasks.append(self._interview_single_platform(agent_id, prompt, "twitter")) platforms_to_interview.append("twitter") if self.reddit_env: tasks.append(self._interview_single_platform(agent_id, prompt, "reddit")) platforms_to_interview.append("reddit") # 并行执行 platform_results = await asyncio.gather(*tasks) for platform_name, platform_result in zip(platforms_to_interview, platform_results): results["platforms"][platform_name] = platform_result if "error" not in platform_result: success_count += 1 if success_count > 0: self.send_response(command_id, "completed", result=results) print(f" Interview完成: agent_id={agent_id}, 成功平台数={success_count}/{len(platforms_to_interview)}") return True else: errors = [f"{p}: {r.get('error', '未知错误')}" for p, r in results["platforms"].items()] self.send_response(command_id, "failed", error="; ".join(errors)) print(f" Interview失败: agent_id={agent_id}, 所有平台都失败") return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool: """ 处理批量采访命令 Args: command_id: 命令ID interviews: [{"agent_id": int, "prompt": str, "platform": str(optional)}, ...] 
platform: 默认平台(可被每个interview项覆盖) - "twitter": 只采访Twitter平台 - "reddit": 只采访Reddit平台 - None/不指定: 每个Agent同时采访两个平台 """ # 按平台分组 twitter_interviews = [] reddit_interviews = [] both_platforms_interviews = [] # 需要同时采访两个平台的 for interview in interviews: item_platform = interview.get("platform", platform) if item_platform == "twitter": twitter_interviews.append(interview) elif item_platform == "reddit": reddit_interviews.append(interview) else: # 未指定平台:两个平台都采访 both_platforms_interviews.append(interview) # 把 both_platforms_interviews 拆分到两个平台 if both_platforms_interviews: if self.twitter_env: twitter_interviews.extend(both_platforms_interviews) if self.reddit_env: reddit_interviews.extend(both_platforms_interviews) results = {} # 处理Twitter平台的采访 if twitter_interviews and self.twitter_env: try: twitter_actions = {} for interview in twitter_interviews: agent_id = interview.get("agent_id") prompt = interview.get("prompt", "") try: agent = self.twitter_agent_graph.get_agent(agent_id) twitter_actions[agent] = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) except Exception as e: print(f" 警告: 无法获取Twitter Agent {agent_id}: {e}") if twitter_actions: await self.twitter_env.step(twitter_actions) for interview in twitter_interviews: agent_id = interview.get("agent_id") result = self._get_interview_result(agent_id, "twitter") result["platform"] = "twitter" results[f"twitter_{agent_id}"] = result except Exception as e: print(f" Twitter批量Interview失败: {e}") # 处理Reddit平台的采访 if reddit_interviews and self.reddit_env: try: reddit_actions = {} for interview in reddit_interviews: agent_id = interview.get("agent_id") prompt = interview.get("prompt", "") try: agent = self.reddit_agent_graph.get_agent(agent_id) reddit_actions[agent] = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) except Exception as e: print(f" 警告: 无法获取Reddit Agent {agent_id}: {e}") if reddit_actions: await self.reddit_env.step(reddit_actions) for interview in 
reddit_interviews: agent_id = interview.get("agent_id") result = self._get_interview_result(agent_id, "reddit") result["platform"] = "reddit" results[f"reddit_{agent_id}"] = result except Exception as e: print(f" Reddit批量Interview失败: {e}") if results: self.send_response(command_id, "completed", result={ "interviews_count": len(results), "results": results }) print(f" 批量Interview完成: {len(results)} 个Agent") return True else: self.send_response(command_id, "failed", error="没有成功的采访") return False def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: """从数据库获取最新的Interview结果""" db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db") result = { "agent_id": agent_id, "response": None, "timestamp": None } if not os.path.exists(db_path): return result try: conn = sqlite3.connect(db_path) cursor = conn.cursor() # 查询最新的Interview记录 cursor.execute(""" SELECT user_id, info, created_at FROM trace WHERE action = ? AND user_id = ? ORDER BY created_at DESC LIMIT 1 """, (ActionType.INTERVIEW.value, agent_id)) row = cursor.fetchone() if row: user_id, info_json, created_at = row try: info = json.loads(info_json) if info_json else {} result["response"] = info.get("response", info) result["timestamp"] = created_at except json.JSONDecodeError: result["response"] = info_json conn.close() except Exception as e: print(f" 读取Interview结果失败: {e}") return result async def process_commands(self) -> bool: """ 处理所有待处理命令 Returns: True 表示继续运行,False 表示应该退出 """ command = self.poll_command() if not command: return True command_id = command.get("command_id") command_type = command.get("command_type") args = command.get("args", {}) print(f"\n收到IPC命令: {command_type}, id={command_id}") if command_type == CommandType.INTERVIEW: await self.handle_interview( command_id, args.get("agent_id", 0), args.get("prompt", ""), args.get("platform") ) return True elif command_type == CommandType.BATCH_INTERVIEW: await self.handle_batch_interview( command_id, 
def fetch_new_actions_from_db(
    db_path: str,
    last_rowid: int,
    agent_names: Dict[int, str]
) -> Tuple[List[Dict[str, Any]], int]:
    """
    Read the actions recorded after ``last_rowid`` and enrich them with context.

    Args:
        db_path: Path of the platform SQLite database.
        last_rowid: Largest ``trace.rowid`` already processed. rowid is used
            instead of created_at for tracking progress.
        agent_names: Mapping of agent_id -> display name.

    Returns:
        ``(actions, new_last_rowid)``:
        - actions: dicts with ``agent_id``, ``agent_name``, ``action_type`` and
          ``action_args`` (including the added context fields).
        - new_last_rowid: largest rowid seen during this call, or ``last_rowid``
          when nothing new was read.

        Read errors are printed and whatever was collected so far is returned.
    """
    actions: List[Dict[str, Any]] = []
    new_last_rowid = last_rowid

    if not os.path.exists(db_path):
        return actions, new_last_rowid

    # Fields copied from the raw action info, in output order (values are not truncated).
    kept_keys = (
        'content', 'post_id', 'comment_id', 'quoted_id', 'new_post_id',
        'follow_id', 'query', 'like_id', 'dislike_id',
    )

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        cursor.execute("""
            SELECT rowid, user_id, action, info
            FROM trace
            WHERE rowid > ?
            ORDER BY rowid ASC
        """, (last_rowid,))

        # Materialize the rows first: the same cursor is handed to
        # _enrich_action_context below, which runs further queries on it.
        for rowid, user_id, action, info_json in cursor.fetchall():
            new_last_rowid = rowid

            # Skip non-core actions.
            if action in FILTERED_ACTIONS:
                continue

            try:
                action_args = json.loads(info_json) if info_json else {}
            except json.JSONDecodeError:
                action_args = {}
            if not isinstance(action_args, dict):
                # Valid JSON that is not an object carries no named fields.
                action_args = {}

            simplified_args = {key: action_args[key] for key in kept_keys if key in action_args}

            # Map the stored action name to its canonical form.
            action_type = ACTION_TYPE_MAP.get(action) or str(action).upper()

            # Add post content, user names, etc.
            _enrich_action_context(cursor, action_type, simplified_args, agent_names)

            actions.append({
                'agent_id': user_id,
                'agent_name': agent_names.get(user_id, f'Agent_{user_id}'),
                'action_type': action_type,
                'action_args': simplified_args,
            })

    except Exception as e:
        print(f"读取数据库动作失败: {e}")
    finally:
        # Close the connection on the error path as well.
        if conn is not None:
            conn.close()

    return actions, new_last_rowid
""" 为动作补充上下文信息(帖子内容、用户名等) Args: cursor: 数据库游标 action_type: 动作类型 action_args: 动作参数(会被修改) agent_names: agent_id -> agent_name 映射 """ try: # 点赞/踩帖子:补充帖子内容和作者 if action_type in ('LIKE_POST', 'DISLIKE_POST'): post_id = action_args.get('post_id') if post_id: post_info = _get_post_info(cursor, post_id, agent_names) if post_info: action_args['post_content'] = post_info.get('content', '') action_args['post_author_name'] = post_info.get('author_name', '') # 转发帖子:补充原帖内容和作者 elif action_type == 'REPOST': new_post_id = action_args.get('new_post_id') if new_post_id: # 转发帖子的 original_post_id 指向原帖 cursor.execute(""" SELECT original_post_id FROM post WHERE post_id = ? """, (new_post_id,)) row = cursor.fetchone() if row and row[0]: original_post_id = row[0] original_info = _get_post_info(cursor, original_post_id, agent_names) if original_info: action_args['original_content'] = original_info.get('content', '') action_args['original_author_name'] = original_info.get('author_name', '') # 引用帖子:补充原帖内容、作者和引用评论 elif action_type == 'QUOTE_POST': quoted_id = action_args.get('quoted_id') new_post_id = action_args.get('new_post_id') if quoted_id: original_info = _get_post_info(cursor, quoted_id, agent_names) if original_info: action_args['original_content'] = original_info.get('content', '') action_args['original_author_name'] = original_info.get('author_name', '') # 获取引用帖子的评论内容(quote_content) if new_post_id: cursor.execute(""" SELECT quote_content FROM post WHERE post_id = ? """, (new_post_id,)) row = cursor.fetchone() if row and row[0]: action_args['quote_content'] = row[0] # 关注用户:补充被关注用户的名称 elif action_type == 'FOLLOW': follow_id = action_args.get('follow_id') if follow_id: # 从 follow 表获取 followee_id cursor.execute(""" SELECT followee_id FROM follow WHERE follow_id = ? 
""", (follow_id,)) row = cursor.fetchone() if row: followee_id = row[0] target_name = _get_user_name(cursor, followee_id, agent_names) if target_name: action_args['target_user_name'] = target_name # 屏蔽用户:补充被屏蔽用户的名称 elif action_type == 'MUTE': # 从 action_args 中获取 user_id 或 target_id target_id = action_args.get('user_id') or action_args.get('target_id') if target_id: target_name = _get_user_name(cursor, target_id, agent_names) if target_name: action_args['target_user_name'] = target_name # 点赞/踩评论:补充评论内容和作者 elif action_type in ('LIKE_COMMENT', 'DISLIKE_COMMENT'): comment_id = action_args.get('comment_id') if comment_id: comment_info = _get_comment_info(cursor, comment_id, agent_names) if comment_info: action_args['comment_content'] = comment_info.get('content', '') action_args['comment_author_name'] = comment_info.get('author_name', '') # 发表评论:补充所评论的帖子信息 elif action_type == 'CREATE_COMMENT': post_id = action_args.get('post_id') if post_id: post_info = _get_post_info(cursor, post_id, agent_names) if post_info: action_args['post_content'] = post_info.get('content', '') action_args['post_author_name'] = post_info.get('author_name', '') except Exception as e: # 补充上下文失败不影响主流程 print(f"补充动作上下文失败: {e}") def _get_post_info( cursor, post_id: int, agent_names: Dict[int, str] ) -> Optional[Dict[str, str]]: """ 获取帖子信息 Args: cursor: 数据库游标 post_id: 帖子ID agent_names: agent_id -> agent_name 映射 Returns: 包含 content 和 author_name 的字典,或 None """ try: cursor.execute(""" SELECT p.content, p.user_id, u.agent_id FROM post p LEFT JOIN user u ON p.user_id = u.user_id WHERE p.post_id = ? 
""", (post_id,)) row = cursor.fetchone() if row: content = row[0] or '' user_id = row[1] agent_id = row[2] # 优先使用 agent_names 中的名称 author_name = '' if agent_id is not None and agent_id in agent_names: author_name = agent_names[agent_id] elif user_id: # 从 user 表获取名称 cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,)) user_row = cursor.fetchone() if user_row: author_name = user_row[0] or user_row[1] or '' return {'content': content, 'author_name': author_name} except Exception: pass return None def _get_user_name( cursor, user_id: int, agent_names: Dict[int, str] ) -> Optional[str]: """ 获取用户名称 Args: cursor: 数据库游标 user_id: 用户ID agent_names: agent_id -> agent_name 映射 Returns: 用户名称,或 None """ try: cursor.execute(""" SELECT agent_id, name, user_name FROM user WHERE user_id = ? """, (user_id,)) row = cursor.fetchone() if row: agent_id = row[0] name = row[1] user_name = row[2] # 优先使用 agent_names 中的名称 if agent_id is not None and agent_id in agent_names: return agent_names[agent_id] return name or user_name or '' except Exception: pass return None def _get_comment_info( cursor, comment_id: int, agent_names: Dict[int, str] ) -> Optional[Dict[str, str]]: """ 获取评论信息 Args: cursor: 数据库游标 comment_id: 评论ID agent_names: agent_id -> agent_name 映射 Returns: 包含 content 和 author_name 的字典,或 None """ try: cursor.execute(""" SELECT c.content, c.user_id, u.agent_id FROM comment c LEFT JOIN user u ON c.user_id = u.user_id WHERE c.comment_id = ? 
""", (comment_id,)) row = cursor.fetchone() if row: content = row[0] or '' user_id = row[1] agent_id = row[2] # 优先使用 agent_names 中的名称 author_name = '' if agent_id is not None and agent_id in agent_names: author_name = agent_names[agent_id] elif user_id: # 从 user 表获取名称 cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,)) user_row = cursor.fetchone() if user_row: author_name = user_row[0] or user_row[1] or '' return {'content': content, 'author_name': author_name} except Exception: pass return None def create_model(config: Dict[str, Any], use_boost: bool = False): """ 创建LLM模型 支持双 LLM 配置,用于并行模拟时提速: - 通用配置:LLM_API_KEY, LLM_BASE_URL, LLM_MODEL_NAME - 加速配置(可选):LLM_BOOST_API_KEY, LLM_BOOST_BASE_URL, LLM_BOOST_MODEL_NAME 如果配置了加速 LLM,并行模拟时可以让不同平台使用不同的 API 服务商,提高并发能力。 Args: config: 模拟配置字典 use_boost: 是否使用加速 LLM 配置(如果可用) """ # 检查是否有加速配置 boost_api_key = os.environ.get("LLM_BOOST_API_KEY", "") boost_base_url = os.environ.get("LLM_BOOST_BASE_URL", "") boost_model = os.environ.get("LLM_BOOST_MODEL_NAME", "") has_boost_config = bool(boost_api_key) # 根据参数和配置情况选择使用哪个 LLM if use_boost and has_boost_config: # 使用加速配置 llm_api_key = boost_api_key llm_base_url = boost_base_url llm_model = boost_model or os.environ.get("LLM_MODEL_NAME", "") config_label = "[加速LLM]" else: # 使用通用配置 llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") config_label = "[通用LLM]" # 如果 .env 中没有模型名,则使用 config 作为备用 if not llm_model: llm_model = config.get("llm_model", "gpt-4o-mini") # 设置 camel-ai 所需的环境变量 if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url print(f"{config_label} model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") return ModelFactory.create( 
async def run_twitter_simulation(
    config: Dict[str, Any], 
    simulation_dir: str,
    action_logger: Optional[PlatformActionLogger] = None,
    main_logger: Optional[SimulationLogManager] = None,
    max_rounds: Optional[int] = None
) -> PlatformSimulation:
    """Run the Twitter simulation.

    Args:
        config: Simulation configuration.
        simulation_dir: Directory holding the profiles and receiving the database.
        action_logger: Per-platform action logger (optional).
        main_logger: Main log manager (optional).
        max_rounds: Upper bound on the number of rounds (optional; used to truncate
            long simulations). Ignored when ``None`` or not positive.

    Returns:
        PlatformSimulation: Holds ``env``, ``agent_graph`` and ``total_actions``.
        The environment is left open so it can still serve interview commands.
        When the profile file is missing, an empty result is returned.
    """
    result = PlatformSimulation()
    
    def log_info(msg):
        if main_logger:
            main_logger.info(f"[Twitter] {msg}")
        print(f"[Twitter] {msg}")
    
    log_info("初始化...")
    
    # Twitter uses the general LLM configuration.
    model = create_model(config, use_boost=False)
    
    # OASIS Twitter profiles are stored as CSV.
    profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
    if not os.path.exists(profile_path):
        log_info(f"错误: Profile文件不存在: {profile_path}")
        return result
    
    result.agent_graph = await generate_twitter_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=TWITTER_ACTIONS,
    )
    
    # Real entity names from the config (instead of the default Agent_X) ...
    agent_names = get_agent_names_from_config(config)
    # ... falling back to the agent's own name for agents missing from the config.
    for agent_id, agent in result.agent_graph.get_agents():
        if agent_id not in agent_names:
            agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
    
    # Start from a fresh database.
    db_path = os.path.join(simulation_dir, "twitter_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)
    
    result.env = oasis.make(
        agent_graph=result.agent_graph,
        platform=oasis.DefaultPlatformType.TWITTER,
        database_path=db_path,
        semaphore=30,  # Cap concurrent LLM requests to avoid overloading the API
    )
    
    await result.env.reset()
    log_info("环境已启动")
    
    if action_logger:
        action_logger.log_simulation_start(config)
    
    total_actions = 0
    last_rowid = 0  # Last processed trace rowid (rowid avoids created_at format differences)
    
    # Initial events
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])
    
    # Round 0 is the initial-event phase (simulated hour 0).
    if action_logger:
        action_logger.log_round_start(0, 0)
    
    initial_action_count = 0
    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = result.env.agent_graph.get_agent(agent_id)
                post_action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # An agent may own several initial posts: collect them in a list
                # (as the Reddit runner does) instead of overwriting the earlier one.
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(post_action)
                else:
                    initial_actions[agent] = post_action
                
                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content}
                    )
                    total_actions += 1
                    initial_action_count += 1
            except Exception as e:
                # Best effort: a bad initial post must not stop the simulation,
                # but it should be visible in the log.
                log_info(f"警告: 初始帖子发布失败 (agent_id={agent_id}): {e}")
        
        if initial_actions:
            await result.env.step(initial_actions)
            # Count posts, not agents: one agent may have queued several.
            queued_posts = sum(
                len(queued) if isinstance(queued, list) else 1
                for queued in initial_actions.values()
            )
            log_info(f"已发布 {queued_posts} 条初始帖子")
    
    # End of round 0
    if action_logger:
        action_logger.log_round_end(0, initial_action_count)
    
    # Main simulation loop
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round
    
    # Truncate when a positive round limit was given.
    if max_rounds is not None and max_rounds > 0:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
    
    start_time = datetime.now()
    
    for round_num in range(total_rounds):
        # Stop early when a shutdown signal was received.
        if _shutdown_event and _shutdown_event.is_set():
            if main_logger:
                main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟")
            break
        
        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1
        
        active_agents = get_active_agents_for_round(
            result.env, config, simulated_hour, round_num
        )
        
        # The round start is logged whether or not any agent is active.
        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour)
        
        if not active_agents:
            # Still close the round in the log (actions_count=0).
            if action_logger:
                action_logger.log_round_end(round_num + 1, 0)
            continue
        
        actions = {agent: LLMAction() for _, agent in active_agents}
        await result.env.step(actions)
        
        # Log what the agents actually did, as recorded in the database.
        actual_actions, last_rowid = fetch_new_actions_from_db(
            db_path, last_rowid, agent_names
        )
        
        round_action_count = 0
        for action_data in actual_actions:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    agent_id=action_data['agent_id'],
                    agent_name=action_data['agent_name'],
                    action_type=action_data['action_type'],
                    action_args=action_data['action_args']
                )
                total_actions += 1
                round_action_count += 1
        
        if action_logger:
            action_logger.log_round_end(round_num + 1, round_action_count)
        
        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
    
    # Note: the environment is intentionally left open for interview commands.
    
    if action_logger:
        action_logger.log_simulation_end(total_rounds, total_actions)
    
    result.total_actions = total_actions
    elapsed = (datetime.now() - start_time).total_seconds()
    log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
    
    return result
database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) await result.env.reset() log_info("环境已启动") if action_logger: action_logger.log_simulation_start(config) total_actions = 0 last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异) # 执行初始事件 event_config = config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) # 记录 round 0 开始(初始事件阶段) if action_logger: action_logger.log_round_start(0, 0) # round 0, simulated_hour 0 initial_action_count = 0 if initial_posts: initial_actions = {} for post in initial_posts: agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: agent = result.env.agent_graph.get_agent(agent_id) if agent in initial_actions: if not isinstance(initial_actions[agent], list): initial_actions[agent] = [initial_actions[agent]] initial_actions[agent].append(ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} )) else: initial_actions[agent] = ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} ) if action_logger: action_logger.log_action( round_num=0, agent_id=agent_id, agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"), action_type="CREATE_POST", action_args={"content": content} ) total_actions += 1 initial_action_count += 1 except Exception: pass if initial_actions: await result.env.step(initial_actions) log_info(f"已发布 {len(initial_actions)} 条初始帖子") # 记录 round 0 结束 if action_logger: action_logger.log_round_end(0, initial_action_count) # 主模拟循环 time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round # 如果指定了最大轮数,则截断 if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") start_time = 
datetime.now() for round_num in range(total_rounds): # 检查是否收到退出信号 if _shutdown_event and _shutdown_event.is_set(): if main_logger: main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟") break simulated_minutes = round_num * minutes_per_round simulated_hour = (simulated_minutes // 60) % 24 simulated_day = simulated_minutes // (60 * 24) + 1 active_agents = get_active_agents_for_round( result.env, config, simulated_hour, round_num ) # 无论是否有活跃agent,都记录round开始 if action_logger: action_logger.log_round_start(round_num + 1, simulated_hour) if not active_agents: # 没有活跃agent时也记录round结束(actions_count=0) if action_logger: action_logger.log_round_end(round_num + 1, 0) continue actions = {agent: LLMAction() for _, agent in active_agents} await result.env.step(actions) # 从数据库获取实际执行的动作并记录 actual_actions, last_rowid = fetch_new_actions_from_db( db_path, last_rowid, agent_names ) round_action_count = 0 for action_data in actual_actions: if action_logger: action_logger.log_action( round_num=round_num + 1, agent_id=action_data['agent_id'], agent_name=action_data['agent_name'], action_type=action_data['action_type'], action_args=action_data['action_args'] ) total_actions += 1 round_action_count += 1 if action_logger: action_logger.log_round_end(round_num + 1, round_action_count) if (round_num + 1) % 20 == 0: progress = (round_num + 1) / total_rounds * 100 log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)") # 注意:不关闭环境,保留给Interview使用 if action_logger: action_logger.log_simulation_end(total_rounds, total_actions) result.total_actions = total_actions elapsed = (datetime.now() - start_time).total_seconds() log_info(f"模拟循环完成! 
耗时: {elapsed:.1f}秒, 总动作: {total_actions}") return result async def main(): parser = argparse.ArgumentParser(description='OASIS双平台并行模拟') parser.add_argument( '--config', type=str, required=True, help='配置文件路径 (simulation_config.json)' ) parser.add_argument( '--twitter-only', action='store_true', help='只运行Twitter模拟' ) parser.add_argument( '--reddit-only', action='store_true', help='只运行Reddit模拟' ) parser.add_argument( '--max-rounds', type=int, default=None, help='最大模拟轮数(可选,用于截断过长的模拟)' ) parser.add_argument( '--no-wait', action='store_true', default=False, help='模拟完成后立即关闭环境,不进入等待命令模式' ) args = parser.parse_args() # 在 main 函数开始时创建 shutdown 事件,确保整个程序都能响应退出信号 global _shutdown_event _shutdown_event = asyncio.Event() if not os.path.exists(args.config): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) config = load_config(args.config) simulation_dir = os.path.dirname(args.config) or "." wait_for_commands = not args.no_wait # 初始化日志配置(禁用 OASIS 日志,清理旧文件) init_logging_for_simulation(simulation_dir) # 创建日志管理器 log_manager = SimulationLogManager(simulation_dir) twitter_logger = log_manager.get_twitter_logger() reddit_logger = log_manager.get_reddit_logger() log_manager.info("=" * 60) log_manager.info("OASIS 双平台并行模拟") log_manager.info(f"配置文件: {args.config}") log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}") log_manager.info(f"等待命令模式: {'启用' if wait_for_commands else '禁用'}") log_manager.info("=" * 60) time_config = config.get("time_config", {}) total_hours = time_config.get('total_simulation_hours', 72) minutes_per_round = time_config.get('minutes_per_round', 30) config_total_rounds = (total_hours * 60) // minutes_per_round log_manager.info(f"模拟参数:") log_manager.info(f" - 总模拟时长: {total_hours}小时") log_manager.info(f" - 每轮时间: {minutes_per_round}分钟") log_manager.info(f" - 配置总轮数: {config_total_rounds}") if args.max_rounds: log_manager.info(f" - 最大轮数限制: {args.max_rounds}") if args.max_rounds < config_total_rounds: log_manager.info(f" - 实际执行轮数: {args.max_rounds} (已截断)") 
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}") log_manager.info("日志结构:") log_manager.info(f" - 主日志: simulation.log") log_manager.info(f" - Twitter动作: twitter/actions.jsonl") log_manager.info(f" - Reddit动作: reddit/actions.jsonl") log_manager.info("=" * 60) start_time = datetime.now() # 存储两个平台的模拟结果 twitter_result: Optional[PlatformSimulation] = None reddit_result: Optional[PlatformSimulation] = None if args.twitter_only: twitter_result = await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds) elif args.reddit_only: reddit_result = await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds) else: # 并行运行(每个平台使用独立的日志记录器) results = await asyncio.gather( run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds), run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds), ) twitter_result, reddit_result = results total_elapsed = (datetime.now() - start_time).total_seconds() log_manager.info("=" * 60) log_manager.info(f"模拟循环完成! 
def setup_signal_handlers(loop=None):
    """
    Install SIGTERM/SIGINT handlers that request a graceful shutdown.

    The simulation stays alive after it finishes so that it can answer interview
    commands. On the first termination signal the handler only sets the global
    shutdown event, so the asyncio loop can leave its wait and release resources.
    A repeated signal exits the process with status 1.

    Args:
        loop: Unused; accepted for call compatibility.
    """
    def _on_termination(signum, frame):
        global _cleanup_done
        label = "SIGINT" if signum != signal.SIGTERM else "SIGTERM"
        print(f"\n收到 {label} 信号,正在退出...")

        # Repeated signal: the graceful path was already requested, so force the exit.
        if _cleanup_done:
            print("强制退出...")
            sys.exit(1)

        # First signal: no sys.exit() here, just wake the asyncio loop so it can
        # exit normally and clean up.
        _cleanup_done = True
        if _shutdown_event:
            _shutdown_event.set()

    for signo in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signo, _on_termination)
class UnicodeFormatter(logging.Formatter):
    """Formatter that turns literal ``\\uXXXX`` escape sequences into readable characters.

    Escaped UTF-16 surrogate pairs (e.g. ``\\ud83d\\ude00``, as produced by JSON
    encoders for characters outside the BMP) are combined into the single character
    they encode. A surrogate escape without a matching partner is left as text,
    because a lone surrogate cannot be encoded as UTF-8 by the file handlers.
    """
    
    UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
    # High surrogate (D800-DBFF) immediately followed by a low surrogate (DC00-DFFF).
    _SURROGATE_PAIR_PATTERN = re.compile(
        r'\\u([dD][89abAB][0-9a-fA-F]{2})\\u([dD][c-fC-F][0-9a-fA-F]{2})'
    )
    
    def format(self, record):
        """Format ``record`` normally, then decode ``\\uXXXX`` escapes in the result."""
        result = super().format(record)
        
        def replace_pair(match):
            high = int(match.group(1), 16)
            low = int(match.group(2), 16)
            return chr(0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00))
        
        def replace_unicode(match):
            try:
                code_point = int(match.group(1), 16)
            except (ValueError, OverflowError):
                return match.group(0)
            # Unpaired surrogate: keep the escape text rather than emit an
            # unencodable character.
            if 0xD800 <= code_point <= 0xDFFF:
                return match.group(0)
            return chr(code_point)
        
        # Pairs first, so their halves are not treated as lone surrogates below.
        result = self._SURROGATE_PAIR_PATTERN.sub(replace_pair, result)
        return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result)
class IPCHandler:
    """File-based IPC command handler.

    Commands are JSON files in ``<simulation_dir>/ipc_commands``; each response is
    written as ``<command_id>.json`` in ``<simulation_dir>/ipc_responses``.
    """
    
    def __init__(self, simulation_dir: str, env, agent_graph):
        """
        Args:
            simulation_dir: Simulation directory holding the IPC folders and database.
            env: Running OASIS environment used to execute interview actions.
            agent_graph: Agent graph used to resolve agents by ID.
        """
        self.simulation_dir = simulation_dir
        self.env = env
        self.agent_graph = agent_graph
        self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
        self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
        self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
        self._running = True
        
        # Make sure both IPC directories exist.
        os.makedirs(self.commands_dir, exist_ok=True)
        os.makedirs(self.responses_dir, exist_ok=True)
    
    def update_status(self, status: str):
        """Write the environment status and the current timestamp to the status file."""
        with open(self.status_file, 'w', encoding='utf-8') as f:
            json.dump({
                "status": status,
                "timestamp": datetime.now().isoformat()
            }, f, ensure_ascii=False, indent=2)
    
    def poll_command(self) -> Optional[Dict[str, Any]]:
        """Return the oldest readable pending command, or ``None`` if there is none.

        Files that cannot be read or parsed are skipped.
        """
        if not os.path.exists(self.commands_dir):
            return None
        
        # Collect command files, oldest first (by modification time).
        command_files = []
        for filename in os.listdir(self.commands_dir):
            if filename.endswith('.json'):
                filepath = os.path.join(self.commands_dir, filename)
                try:
                    command_files.append((filepath, os.path.getmtime(filepath)))
                except OSError:
                    # The file vanished between listdir() and getmtime(); ignore it.
                    continue
        
        command_files.sort(key=lambda x: x[1])
        
        for filepath, _ in command_files:
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except (json.JSONDecodeError, OSError):
                continue
        
        return None
    
    def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
        """Write the response file for ``command_id`` and delete its command file."""
        response = {
            "command_id": command_id,
            "status": status,
            "result": result,
            "error": error,
            "timestamp": datetime.now().isoformat()
        }
        
        response_file = os.path.join(self.responses_dir, f"{command_id}.json")
        with open(response_file, 'w', encoding='utf-8') as f:
            json.dump(response, f, ensure_ascii=False, indent=2)
        
        # Remove the command file so the command is not processed again.
        command_file = os.path.join(self.commands_dir, f"{command_id}.json")
        try:
            os.remove(command_file)
        except OSError:
            pass
    
    async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
        """
        Handle a single-agent interview command.
        
        Returns:
            True on success, False on failure. A response is sent in both cases.
        """
        try:
            agent = self.agent_graph.get_agent(agent_id)
            
            interview_action = ManualAction(
                action_type=ActionType.INTERVIEW,
                action_args={"prompt": prompt}
            )
            
            actions = {agent: interview_action}
            await self.env.step(actions)
            
            # The answer is read back from the database.
            result = self._get_interview_result(agent_id)
            
            self.send_response(command_id, "completed", result=result)
            print(f"  Interview完成: agent_id={agent_id}")
            return True
            
        except Exception as e:
            error_msg = str(e)
            print(f"  Interview失败: agent_id={agent_id}, error={error_msg}")
            self.send_response(command_id, "failed", error=error_msg)
            return False
    
    async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
        """
        Handle a batch interview command.
        
        Args:
            interviews: [{"agent_id": int, "prompt": str}, ...]
        
        Returns:
            True on success, False when no agent could be resolved or the step fails.
        """
        try:
            actions = {}
            agent_prompts = {}  # prompt per successfully resolved agent
            
            for interview in interviews:
                agent_id = interview.get("agent_id")
                prompt = interview.get("prompt", "")
                
                try:
                    agent = self.agent_graph.get_agent(agent_id)
                    actions[agent] = ManualAction(
                        action_type=ActionType.INTERVIEW,
                        action_args={"prompt": prompt}
                    )
                    agent_prompts[agent_id] = prompt
                except Exception as e:
                    print(f"  警告: 无法获取Agent {agent_id}: {e}")
            
            if not actions:
                self.send_response(command_id, "failed", error="没有有效的Agent")
                return False
            
            # Run all interviews in a single environment step.
            await self.env.step(actions)
            
            results = {}
            for agent_id in agent_prompts.keys():
                result = self._get_interview_result(agent_id)
                results[agent_id] = result
            
            self.send_response(command_id, "completed", result={
                "interviews_count": len(results),
                "results": results
            })
            print(f"  批量Interview完成: {len(results)} 个Agent")
            return True
            
        except Exception as e:
            error_msg = str(e)
            print(f"  批量Interview失败: {error_msg}")
            self.send_response(command_id, "failed", error=error_msg)
            return False
    
    def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
        """Return the newest interview record of ``agent_id`` from the database.

        Returns:
            ``{"agent_id", "response", "timestamp"}``; ``response`` and ``timestamp``
            stay ``None`` when the database or a matching row is missing, or when
            reading fails.
        """
        db_path = os.path.join(self.simulation_dir, "reddit_simulation.db")
        
        result = {
            "agent_id": agent_id,
            "response": None,
            "timestamp": None
        }
        
        if not os.path.exists(db_path):
            return result
        
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            
            # Newest interview first; rowid breaks ties between rows that share
            # the same created_at value.
            cursor.execute("""
                SELECT user_id, info, created_at
                FROM trace
                WHERE action = ? AND user_id = ?
                ORDER BY created_at DESC, rowid DESC
                LIMIT 1
            """, (ActionType.INTERVIEW.value, agent_id))
            
            row = cursor.fetchone()
            if row:
                user_id, info_json, created_at = row
                try:
                    info = json.loads(info_json) if info_json else {}
                    result["response"] = info.get("response", info)
                    result["timestamp"] = created_at
                except json.JSONDecodeError:
                    # Not JSON: hand back the raw text.
                    result["response"] = info_json
            
        except Exception as e:
            print(f"  读取Interview结果失败: {e}")
        finally:
            # Release the connection on the error path as well.
            if conn is not None:
                conn.close()
        
        return result
    
    async def process_commands(self) -> bool:
        """
        Process one pending command, if any.
        
        Returns:
            True to keep running, False when the environment should shut down.
        """
        command = self.poll_command()
        if not command:
            return True
        
        command_id = command.get("command_id")
        command_type = command.get("command_type")
        args = command.get("args", {})
        
        print(f"\n收到IPC命令: {command_type}, id={command_id}")
        
        if command_type == CommandType.INTERVIEW:
            await self.handle_interview(
                command_id,
                args.get("agent_id", 0),
                args.get("prompt", "")
            )
            return True
            
        elif command_type == CommandType.BATCH_INTERVIEW:
            await self.handle_batch_interview(
                command_id,
                args.get("interviews", [])
            )
            return True
            
        elif command_type == CommandType.CLOSE_ENV:
            print("收到关闭环境命令")
            self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
            return False
        
        else:
            self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
            return True
self.wait_for_commands = wait_for_commands self.env = None self.agent_graph = None self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: """加载配置文件""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) def _get_profile_path(self) -> str: """获取Profile文件路径""" return os.path.join(self.simulation_dir, "reddit_profiles.json") def _get_db_path(self) -> str: """获取数据库路径""" return os.path.join(self.simulation_dir, "reddit_simulation.db") def _create_model(self): """ 创建LLM模型 统一使用项目根目录 .env 文件中的配置(优先级最高): - LLM_API_KEY: API密钥 - LLM_BASE_URL: API基础URL - LLM_MODEL_NAME: 模型名称 """ # 优先从 .env 读取配置 llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") # 如果 .env 中没有,则使用 config 作为备用 if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") # 设置 camel-ai 所需的环境变量 if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, ) def _get_active_agents_for_round( self, env, current_hour: int, round_num: int ) -> List: """ 根据时间和配置决定本轮激活哪些Agent """ time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) base_min = time_config.get("agents_per_hour_min", 5) base_max = time_config.get("agents_per_hour_max", 20) peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) if current_hour in peak_hours: multiplier = time_config.get("peak_activity_multiplier", 1.5) elif current_hour in off_peak_hours: multiplier = time_config.get("off_peak_activity_multiplier", 0.3) else: 
multiplier = 1.0 target_count = int(random.uniform(base_min, base_max) * multiplier) candidates = [] for cfg in agent_configs: agent_id = cfg.get("agent_id", 0) active_hours = cfg.get("active_hours", list(range(8, 23))) activity_level = cfg.get("activity_level", 0.5) if current_hour not in active_hours: continue if random.random() < activity_level: candidates.append(agent_id) selected_ids = random.sample( candidates, min(target_count, len(candidates)) ) if candidates else [] active_agents = [] for agent_id in selected_ids: try: agent = env.agent_graph.get_agent(agent_id) active_agents.append((agent_id, agent)) except Exception: pass return active_agents async def run(self, max_rounds: int = None): """运行Reddit模拟 Args: max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) """ print("=" * 60) print("OASIS Reddit模拟") print(f"配置文件: {self.config_path}") print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print("=" * 60) time_config = self.config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round # 如果指定了最大轮数,则截断 if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") print(f"\n模拟参数:") print(f" - 总模拟时长: {total_hours}小时") print(f" - 每轮时间: {minutes_per_round}分钟") print(f" - 总轮数: {total_rounds}") if max_rounds: print(f" - 最大轮数限制: {max_rounds}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") print("\n初始化LLM模型...") model = self._create_model() print("加载Agent Profile...") profile_path = self._get_profile_path() if not os.path.exists(profile_path): print(f"错误: Profile文件不存在: {profile_path}") return self.agent_graph = await generate_reddit_agent_graph( profile_path=profile_path, 
model=model, available_actions=self.AVAILABLE_ACTIONS, ) db_path = self._get_db_path() if os.path.exists(db_path): os.remove(db_path) print(f"已删除旧数据库: {db_path}") print("创建OASIS环境...") self.env = oasis.make( agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) await self.env.reset() print("环境初始化完成\n") # 初始化IPC处理器 self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") # 执行初始事件 event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) if initial_posts: print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...") initial_actions = {} for post in initial_posts: agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: agent = self.env.agent_graph.get_agent(agent_id) if agent in initial_actions: if not isinstance(initial_actions[agent], list): initial_actions[agent] = [initial_actions[agent]] initial_actions[agent].append(ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} )) else: initial_actions[agent] = ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} ) except Exception as e: print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") if initial_actions: await self.env.step(initial_actions) print(f" 已发布 {len(initial_actions)} 条初始帖子") # 主模拟循环 print("\n开始模拟循环...") start_time = datetime.now() for round_num in range(total_rounds): simulated_minutes = round_num * minutes_per_round simulated_hour = (simulated_minutes // 60) % 24 simulated_day = simulated_minutes // (60 * 24) + 1 active_agents = self._get_active_agents_for_round( self.env, simulated_hour, round_num ) if not active_agents: continue actions = { agent: LLMAction() for _, agent in active_agents } await self.env.step(actions) if (round_num + 1) % 10 == 0 or round_num == 0: elapsed = (datetime.now() - start_time).total_seconds() progress = (round_num + 1) / 
async def main():
    """Command-line entry point for the Reddit simulation script.

    Parses ``--config``, ``--max-rounds`` and ``--no-wait``, creates the
    module-level shutdown event, exits with status 1 when the config file
    does not exist, sets up logging in the ``log`` directory next to the
    config file and finally awaits the simulation run.
    """
    global _shutdown_event

    cli = argparse.ArgumentParser(description='OASIS Reddit模拟')
    option_specs = (
        ('--config', dict(
            type=str,
            required=True,
            help='配置文件路径 (simulation_config.json)'
        )),
        ('--max-rounds', dict(
            type=int,
            default=None,
            help='最大模拟轮数(可选,用于截断过长的模拟)'
        )),
        ('--no-wait', dict(
            action='store_true',
            default=False,
            help='模拟完成后立即关闭环境,不进入等待命令模式'
        )),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Created here so the event exists before the simulation starts.
    _shutdown_event = asyncio.Event()

    config_path = opts.config
    if not os.path.exists(config_path):
        print(f"错误: 配置文件不存在: {config_path}")
        sys.exit(1)

    # Logs live beside the config file; a bare file name means the current dir.
    log_root = os.path.dirname(config_path) or "."
    setup_oasis_logging(os.path.join(log_root, "log"))

    runner = RedditSimulationRunner(
        config_path=config_path,
        wait_for_commands=not opts.no_wait
    )
    await runner.run(max_rounds=opts.max_rounds)
class UnicodeFormatter(logging.Formatter):
    """Formatter that turns literal ``\\uXXXX`` escape sequences into characters.

    The formatted message is post-processed so escaped text (for example
    JSON-escaped CJK characters) is readable in the log file.

    UTF-16 surrogate pairs such as ``\\ud83d\\ude00`` are combined into the
    single character they encode. An unpaired surrogate escape is left as
    text, because a lone surrogate cannot be encoded as UTF-8 and would make
    a UTF-8 file handler fail on the record.
    """

    UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
    # High surrogate (D800-DBFF) immediately followed by a low one (DC00-DFFF).
    SURROGATE_PAIR_PATTERN = re.compile(
        r'\\u([dD][89abAB][0-9a-fA-F]{2})\\u([dD][c-fC-F][0-9a-fA-F]{2})'
    )

    def format(self, record):
        """Format ``record`` and decode the ``\\uXXXX`` escapes in the result."""
        result = super().format(record)

        def replace_pair(match):
            high = int(match.group(1), 16)
            low = int(match.group(2), 16)
            return chr(0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00))

        def replace_unicode(match):
            code_point = int(match.group(1), 16)
            if 0xD800 <= code_point <= 0xDFFF:
                # Unpaired surrogate: keep the escape text untouched.
                return match.group(0)
            return chr(code_point)

        # Pairs first, so their halves are not decoded separately below.
        result = self.SURROGATE_PAIR_PATTERN.sub(replace_pair, result)
        return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result)
def poll_command(self) -> Optional[Dict[str, Any]]:
    """Return the oldest readable pending command, or ``None``.

    Scans ``self.commands_dir`` for ``*.json`` files, orders them by
    modification time (oldest first) and returns the first one that parses
    to a JSON object. The file itself is not removed here.

    Returns:
        The decoded command dict, or ``None`` when the directory is missing
        or unreadable, or no file yields a JSON object.
    """
    try:
        filenames = os.listdir(self.commands_dir)
    except OSError:
        # Directory missing or unreadable: nothing to poll.
        return None

    command_files = []
    for filename in filenames:
        if not filename.endswith('.json'):
            continue
        filepath = os.path.join(self.commands_dir, filename)
        try:
            command_files.append((filepath, os.path.getmtime(filepath)))
        except OSError:
            # The file vanished between listdir() and getmtime().
            continue

    command_files.sort(key=lambda x: x[1])

    for filepath, _ in command_files:
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                command = json.load(f)
        except (ValueError, OSError):
            # Unreadable, partially written, or not valid JSON/UTF-8.
            continue
        # Callers use dict methods on the result, so skip other JSON values.
        if isinstance(command, dict):
            return command

    return None
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
    """Interview several agents in a single environment step.

    Args:
        command_id: Identifier echoed back in the response.
        interviews: Items shaped like ``{"agent_id": int, "prompt": str}``.

    Returns:
        ``True`` when the step ran and a "completed" response was sent,
        ``False`` when no agent could be resolved or an error occurred (a
        "failed" response is sent in both cases).
    """
    try:
        actions = {}
        # agent_id -> prompt for every agent that received an action
        accepted = {}

        for item in interviews:
            agent_id = item.get("agent_id")
            prompt = item.get("prompt", "")
            try:
                agent = self.agent_graph.get_agent(agent_id)
                actions[agent] = ManualAction(
                    action_type=ActionType.INTERVIEW,
                    action_args={"prompt": prompt}
                )
            except Exception as e:
                print(f" 警告: 无法获取Agent {agent_id}: {e}")
            else:
                accepted[agent_id] = prompt

        if len(actions) == 0:
            self.send_response(command_id, "failed", error="没有有效的Agent")
            return False

        # One step executes every queued interview.
        await self.env.step(actions)

        # Read each agent's answer back from the database.
        results = {
            agent_id: self._get_interview_result(agent_id)
            for agent_id in accepted
        }

        payload = {
            "interviews_count": len(results),
            "results": results
        }
        self.send_response(command_id, "completed", result=payload)
        print(f" 批量Interview完成: {len(results)} 个Agent")
        return True

    except Exception as e:
        error_msg = str(e)
        print(f" 批量Interview失败: {error_msg}")
        self.send_response(command_id, "failed", error=error_msg)
        return False
async def process_commands(self) -> bool:
    """Process at most one pending IPC command.

    Polls for a command and dispatches it by ``command_type``. A missing or
    null ``args`` (or ``interviews``) entry is treated as empty instead of
    raising, so one malformed command file cannot terminate the caller's
    command loop.

    Returns:
        ``True`` to keep running, ``False`` after a close-environment
        command.
    """
    command = self.poll_command()
    if not command:
        return True

    command_id = command.get("command_id")
    command_type = command.get("command_type")
    # `"args": null` makes .get() return None, so normalise to a dict.
    args = command.get("args")
    if not isinstance(args, dict):
        args = {}

    print(f"\n收到IPC命令: {command_type}, id={command_id}")

    if command_type == CommandType.INTERVIEW:
        await self.handle_interview(
            command_id,
            args.get("agent_id", 0),
            args.get("prompt", "")
        )
        return True

    elif command_type == CommandType.BATCH_INTERVIEW:
        await self.handle_batch_interview(
            command_id,
            args.get("interviews") or []
        )
        return True

    elif command_type == CommandType.CLOSE_ENV:
        print("收到关闭环境命令")
        self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
        return False

    else:
        self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
        return True
open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) def _get_profile_path(self) -> str: """获取Profile文件路径(OASIS Twitter使用CSV格式)""" return os.path.join(self.simulation_dir, "twitter_profiles.csv") def _get_db_path(self) -> str: """获取数据库路径""" return os.path.join(self.simulation_dir, "twitter_simulation.db") def _create_model(self): """ 创建LLM模型 统一使用项目根目录 .env 文件中的配置(优先级最高): - LLM_API_KEY: API密钥 - LLM_BASE_URL: API基础URL - LLM_MODEL_NAME: 模型名称 """ # 优先从 .env 读取配置 llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") # 如果 .env 中没有,则使用 config 作为备用 if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") # 设置 camel-ai 所需的环境变量 if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, ) def _get_active_agents_for_round( self, env, current_hour: int, round_num: int ) -> List: """ 根据时间和配置决定本轮激活哪些Agent Args: env: OASIS环境 current_hour: 当前模拟小时(0-23) round_num: 当前轮数 Returns: 激活的Agent列表 """ time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) # 基础激活数量 base_min = time_config.get("agents_per_hour_min", 5) base_max = time_config.get("agents_per_hour_max", 20) # 根据时段调整 peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) if current_hour in peak_hours: multiplier = time_config.get("peak_activity_multiplier", 1.5) elif current_hour in off_peak_hours: multiplier = time_config.get("off_peak_activity_multiplier", 0.3) else: multiplier = 1.0 target_count = 
async def run(self, max_rounds: int = None):
    """Run the Twitter simulation from start-up to environment shutdown.

    Steps: print the effective parameters, build the LLM model and the agent
    graph, create the OASIS environment, publish the configured initial
    posts, execute the simulation rounds and, when ``self.wait_for_commands``
    is true, keep serving IPC commands until a close command or a shutdown
    signal arrives. Once the environment exists it is always closed and the
    status file is set to "stopped", even when a round raises.

    Args:
        max_rounds: Optional upper bound on the number of rounds; ``None``
            or a non-positive value leaves the configured total untouched.

    Returns:
        None. Returns early, before any environment is created, when the
        profile file does not exist.
    """
    print("=" * 60)
    print("OASIS Twitter模拟")
    print(f"配置文件: {self.config_path}")
    print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
    print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
    print("=" * 60)

    # Time configuration
    time_config = self.config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)

    total_rounds = (total_hours * 60) // minutes_per_round

    # Truncate when a round limit was given.
    if max_rounds is not None and max_rounds > 0:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")

    print("\n模拟参数:")
    print(f" - 总模拟时长: {total_hours}小时")
    print(f" - 每轮时间: {minutes_per_round}分钟")
    print(f" - 总轮数: {total_rounds}")
    if max_rounds:
        print(f" - 最大轮数限制: {max_rounds}")
    print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")

    print("\n初始化LLM模型...")
    model = self._create_model()

    print("加载Agent Profile...")
    profile_path = self._get_profile_path()
    if not os.path.exists(profile_path):
        print(f"错误: Profile文件不存在: {profile_path}")
        return

    self.agent_graph = await generate_twitter_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=self.AVAILABLE_ACTIONS,
    )

    # Start from a fresh database.
    db_path = self._get_db_path()
    if os.path.exists(db_path):
        os.remove(db_path)
        print(f"已删除旧数据库: {db_path}")

    print("创建OASIS环境...")
    self.env = oasis.make(
        agent_graph=self.agent_graph,
        platform=oasis.DefaultPlatformType.TWITTER,
        database_path=db_path,
        semaphore=30,  # cap on concurrent LLM requests to avoid overloading the API
    )

    await self.env.reset()
    print("环境初始化完成\n")

    self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
    self.ipc_handler.update_status("running")

    try:
        # Initial events
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])

        if initial_posts:
            print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
            # Collect every post per agent. Assigning directly to
            # actions[agent] would keep only the last post of an agent that
            # has several initial posts.
            queued = {}
            post_count = 0
            for post in initial_posts:
                agent_id = post.get("poster_agent_id", 0)
                content = post.get("content", "")
                try:
                    agent = self.env.agent_graph.get_agent(agent_id)
                    queued.setdefault(agent, []).append(ManualAction(
                        action_type=ActionType.CREATE_POST,
                        action_args={"content": content}
                    ))
                    post_count += 1
                except Exception as e:
                    print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")

            if queued:
                # One post stays a bare action; several become a list, the
                # same shape the Reddit runner passes to env.step().
                initial_actions = {
                    agent: posts[0] if len(posts) == 1 else posts
                    for agent, posts in queued.items()
                }
                await self.env.step(initial_actions)
                print(f" 已发布 {post_count} 条初始帖子")

        # Main simulation loop
        print("\n开始模拟循环...")
        start_time = datetime.now()

        for round_num in range(total_rounds):
            # Honour SIGTERM/SIGINT between rounds rather than only after the
            # whole simulation has finished.
            if _shutdown_event is not None and _shutdown_event.is_set():
                print("\n收到退出信号,提前结束模拟循环")
                break

            # Simulated clock for this round
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1

            active_agents = self._get_active_agents_for_round(
                self.env, simulated_hour, round_num
            )

            if not active_agents:
                continue

            actions = {
                agent: LLMAction()
                for _, agent in active_agents
            }

            await self.env.step(actions)

            # Progress report on the first round and every tenth round
            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")

        total_elapsed = (datetime.now() - start_time).total_seconds()
        print("\n模拟循环完成!")
        print(f" - 总耗时: {total_elapsed:.1f}秒")
        print(f" - 数据库: {db_path}")

        # Optionally stay alive and serve IPC commands
        if self.wait_for_commands:
            print("\n" + "=" * 60)
            print("进入等待命令模式 - 环境保持运行")
            print("支持的命令: interview, batch_interview, close_env")
            print("=" * 60)

            self.ipc_handler.update_status("alive")

            # Command loop driven by the module-level _shutdown_event
            try:
                while not _shutdown_event.is_set():
                    should_continue = await self.ipc_handler.process_commands()
                    if not should_continue:
                        break
                    try:
                        await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
                        break  # shutdown signal received
                    except asyncio.TimeoutError:
                        pass
            except KeyboardInterrupt:
                print("\n收到中断信号")
            except asyncio.CancelledError:
                print("\n任务被取消")
            except Exception as e:
                print(f"\n命令处理出错: {e}")
    finally:
        # Always release the environment, including when a round raised.
        print("\n关闭环境...")
        try:
            self.ipc_handler.update_status("stopped")
        finally:
            await self.env.close()

        print("环境已关闭")
        print("=" * 60)
def setup_signal_handlers():
    """Install one shared handler for SIGTERM and SIGINT.

    The first signal marks cleanup as started and sets the module-level
    shutdown event (when one exists) so the program can shut down in an
    orderly way. Any later signal exits the process with status 1.
    """
    def _on_signal(signum, frame):
        global _cleanup_done

        label = "SIGINT" if signum != signal.SIGTERM else "SIGTERM"
        print(f"\n收到 {label} 信号,正在退出...")

        if _cleanup_done:
            # A repeated signal means the graceful path is taking too long.
            print("强制退出...")
            sys.exit(1)

        _cleanup_done = True
        if _shutdown_event:
            _shutdown_event.set()

    for signo in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signo, _on_signal)
Reddit Profile生成JSON详细格式 """ import os import sys import json import csv import tempfile # 添加项目路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from app.services.oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile def test_profile_formats(): """测试Profile格式""" print("=" * 60) print("OASIS Profile格式测试") print("=" * 60) # 创建测试Profile数据 test_profiles = [ OasisAgentProfile( user_id=0, user_name="test_user_123", name="Test User", bio="A test user for validation", persona="Test User is an enthusiastic participant in social discussions.", karma=1500, friend_count=100, follower_count=200, statuses_count=500, age=25, gender="male", mbti="INTJ", country="China", profession="Student", interested_topics=["Technology", "Education"], source_entity_uuid="test-uuid-123", source_entity_type="Student", ), OasisAgentProfile( user_id=1, user_name="org_official_456", name="Official Organization", bio="Official account for Organization", persona="This is an official institutional account that communicates official positions.", karma=5000, friend_count=50, follower_count=10000, statuses_count=200, profession="Organization", interested_topics=["Public Policy", "Announcements"], source_entity_uuid="test-uuid-456", source_entity_type="University", ), ] generator = OasisProfileGenerator.__new__(OasisProfileGenerator) # 使用临时目录 with tempfile.TemporaryDirectory() as temp_dir: twitter_path = os.path.join(temp_dir, "twitter_profiles.csv") reddit_path = os.path.join(temp_dir, "reddit_profiles.json") # 测试Twitter CSV格式 print("\n1. 
测试Twitter Profile (CSV格式)") print("-" * 40) generator._save_twitter_csv(test_profiles, twitter_path) # 读取并验证CSV with open(twitter_path, 'r', encoding='utf-8') as f: reader = csv.DictReader(f) rows = list(reader) print(f" 文件: {twitter_path}") print(f" 行数: {len(rows)}") print(f" 表头: {list(rows[0].keys())}") print(f"\n 示例数据 (第1行):") for key, value in rows[0].items(): print(f" {key}: {value}") # 验证必需字段 required_twitter_fields = ['user_id', 'user_name', 'name', 'bio', 'friend_count', 'follower_count', 'statuses_count', 'created_at'] missing = set(required_twitter_fields) - set(rows[0].keys()) if missing: print(f"\n [错误] 缺少字段: {missing}") else: print(f"\n [通过] 所有必需字段都存在") # 测试Reddit JSON格式 print("\n2. 测试Reddit Profile (JSON详细格式)") print("-" * 40) generator._save_reddit_json(test_profiles, reddit_path) # 读取并验证JSON with open(reddit_path, 'r', encoding='utf-8') as f: reddit_data = json.load(f) print(f" 文件: {reddit_path}") print(f" 条目数: {len(reddit_data)}") print(f" 字段: {list(reddit_data[0].keys())}") print(f"\n 示例数据 (第1条):") print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4)) # 验证详细格式字段 required_reddit_fields = ['realname', 'username', 'bio', 'persona'] optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics'] missing = set(required_reddit_fields) - set(reddit_data[0].keys()) if missing: print(f"\n [错误] 缺少必需字段: {missing}") else: print(f"\n [通过] 所有必需字段都存在") present_optional = set(optional_reddit_fields) & set(reddit_data[0].keys()) print(f" [信息] 可选字段: {present_optional}") print("\n" + "=" * 60) print("测试完成!") print("=" * 60) def show_expected_formats(): """显示OASIS期望的格式""" print("\n" + "=" * 60) print("OASIS 期望的Profile格式参考") print("=" * 60) print("\n1. 
Twitter Profile (CSV格式)") print("-" * 40) twitter_example = """user_id,user_name,name,bio,friend_count,follower_count,statuses_count,created_at 0,user0,User Zero,I am user zero with interests in technology.,100,150,500,2023-01-01 1,user1,User One,Tech enthusiast and coffee lover.,200,250,1000,2023-01-02""" print(twitter_example) print("\n2. Reddit Profile (JSON详细格式)") print("-" * 40) reddit_example = [ { "realname": "James Miller", "username": "millerhospitality", "bio": "Passionate about hospitality & tourism.", "persona": "James is a seasoned professional in the Hospitality & Tourism industry...", "age": 40, "gender": "male", "mbti": "ESTJ", "country": "UK", "profession": "Hospitality & Tourism", "interested_topics": ["Economics", "Business"] } ] print(json.dumps(reddit_example, ensure_ascii=False, indent=2)) if __name__ == "__main__": test_profile_formats() show_expected_formats() ================================================ FILE: docker-compose.yml ================================================ services: mirofish: image: ghcr.io/666ghj/mirofish:latest # 加速镜像(如拉取缓慢可替换上方地址) # image: ghcr.nju.edu.cn/666ghj/mirofish:latest container_name: mirofish env_file: - .env ports: - "3000:3000" - "5001:5001" restart: unless-stopped volumes: - ./backend/uploads:/app/backend/uploads ================================================ FILE: frontend/.gitignore ================================================ # Logs logs *.log npm-debug.log* yarn-debug.log* yarn-error.log* pnpm-debug.log* lerna-debug.log* node_modules dist dist-ssr *.local # Editor directories and files .vscode/* !.vscode/extensions.json .idea .DS_Store *.suo *.ntvs* *.njsproj *.sln *.sw? ================================================ FILE: frontend/index.html ================================================ MiroFish - 预测万物
================================================ FILE: frontend/package.json ================================================ { "name": "frontend", "private": true, "version": "0.1.0", "type": "module", "scripts": { "dev": "vite --host", "build": "vite build", "preview": "vite preview" }, "dependencies": { "axios": "^1.13.2", "d3": "^7.9.0", "vue": "^3.5.24", "vue-router": "^4.6.3" }, "devDependencies": { "@vitejs/plugin-vue": "^6.0.1", "vite": "^7.2.4" } } ================================================ FILE: frontend/src/App.vue ================================================ ================================================ FILE: frontend/src/api/graph.js ================================================ import service, { requestWithRetry } from './index' /** * 生成本体(上传文档和模拟需求) * @param {Object} data - 包含files, simulation_requirement, project_name等 * @returns {Promise} */ export function generateOntology(formData) { return requestWithRetry(() => service({ url: '/api/graph/ontology/generate', method: 'post', data: formData, headers: { 'Content-Type': 'multipart/form-data' } }) ) } /** * 构建图谱 * @param {Object} data - 包含project_id, graph_name等 * @returns {Promise} */ export function buildGraph(data) { return requestWithRetry(() => service({ url: '/api/graph/build', method: 'post', data }) ) } /** * 查询任务状态 * @param {String} taskId - 任务ID * @returns {Promise} */ export function getTaskStatus(taskId) { return service({ url: `/api/graph/task/${taskId}`, method: 'get' }) } /** * 获取图谱数据 * @param {String} graphId - 图谱ID * @returns {Promise} */ export function getGraphData(graphId) { return service({ url: `/api/graph/data/${graphId}`, method: 'get' }) } /** * 获取项目信息 * @param {String} projectId - 项目ID * @returns {Promise} */ export function getProject(projectId) { return service({ url: `/api/graph/project/${projectId}`, method: 'get' }) } ================================================ FILE: frontend/src/api/index.js ================================================ import 
axios from 'axios' // 创建axios实例 const service = axios.create({ baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:5001', timeout: 300000, // 5分钟超时(本体生成可能需要较长时间) headers: { 'Content-Type': 'application/json' } }) // 请求拦截器 service.interceptors.request.use( config => { return config }, error => { console.error('Request error:', error) return Promise.reject(error) } ) // 响应拦截器(容错重试机制) service.interceptors.response.use( response => { const res = response.data // 如果返回的状态码不是success,则抛出错误 if (!res.success && res.success !== undefined) { console.error('API Error:', res.error || res.message || 'Unknown error') return Promise.reject(new Error(res.error || res.message || 'Error')) } return res }, error => { console.error('Response error:', error) // 处理超时 if (error.code === 'ECONNABORTED' && error.message.includes('timeout')) { console.error('Request timeout') } // 处理网络错误 if (error.message === 'Network Error') { console.error('Network error - please check your connection') } return Promise.reject(error) } ) // 带重试的请求函数 export const requestWithRetry = async (requestFn, maxRetries = 3, delay = 1000) => { for (let i = 0; i < maxRetries; i++) { try { return await requestFn() } catch (error) { if (i === maxRetries - 1) throw error console.warn(`Request failed, retrying (${i + 1}/${maxRetries})...`) await new Promise(resolve => setTimeout(resolve, delay * Math.pow(2, i))) } } } export default service ================================================ FILE: frontend/src/api/report.js ================================================ import service, { requestWithRetry } from './index' /** * 开始报告生成 * @param {Object} data - { simulation_id, force_regenerate? 
} */ export const generateReport = (data) => { return requestWithRetry(() => service.post('/api/report/generate', data), 3, 1000) } /** * 获取报告生成状态 * @param {string} reportId */ export const getReportStatus = (reportId) => { return service.get(`/api/report/generate/status`, { params: { report_id: reportId } }) } /** * 获取 Agent 日志(增量) * @param {string} reportId * @param {number} fromLine - 从第几行开始获取 */ export const getAgentLog = (reportId, fromLine = 0) => { return service.get(`/api/report/${reportId}/agent-log`, { params: { from_line: fromLine } }) } /** * 获取控制台日志(增量) * @param {string} reportId * @param {number} fromLine - 从第几行开始获取 */ export const getConsoleLog = (reportId, fromLine = 0) => { return service.get(`/api/report/${reportId}/console-log`, { params: { from_line: fromLine } }) } /** * 获取报告详情 * @param {string} reportId */ export const getReport = (reportId) => { return service.get(`/api/report/${reportId}`) } /** * 与 Report Agent 对话 * @param {Object} data - { simulation_id, message, chat_history? } */ export const chatWithReport = (data) => { return requestWithRetry(() => service.post('/api/report/chat', data), 3, 1000) } ================================================ FILE: frontend/src/api/simulation.js ================================================ import service, { requestWithRetry } from './index' /** * 创建模拟 * @param {Object} data - { project_id, graph_id?, enable_twitter?, enable_reddit? } */ export const createSimulation = (data) => { return requestWithRetry(() => service.post('/api/simulation/create', data), 3, 1000) } /** * 准备模拟环境(异步任务) * @param {Object} data - { simulation_id, entity_types?, use_llm_for_profiles?, parallel_profile_count?, force_regenerate? } */ export const prepareSimulation = (data) => { return requestWithRetry(() => service.post('/api/simulation/prepare', data), 3, 1000) } /** * 查询准备任务进度 * @param {Object} data - { task_id?, simulation_id? 
} */ export const getPrepareStatus = (data) => { return service.post('/api/simulation/prepare/status', data) } /** * 获取模拟状态 * @param {string} simulationId */ export const getSimulation = (simulationId) => { return service.get(`/api/simulation/${simulationId}`) } /** * 获取模拟的 Agent Profiles * @param {string} simulationId * @param {string} platform - 'reddit' | 'twitter' */ export const getSimulationProfiles = (simulationId, platform = 'reddit') => { return service.get(`/api/simulation/${simulationId}/profiles`, { params: { platform } }) } /** * 实时获取生成中的 Agent Profiles * @param {string} simulationId * @param {string} platform - 'reddit' | 'twitter' */ export const getSimulationProfilesRealtime = (simulationId, platform = 'reddit') => { return service.get(`/api/simulation/${simulationId}/profiles/realtime`, { params: { platform } }) } /** * 获取模拟配置 * @param {string} simulationId */ export const getSimulationConfig = (simulationId) => { return service.get(`/api/simulation/${simulationId}/config`) } /** * 实时获取生成中的模拟配置 * @param {string} simulationId * @returns {Promise} 返回配置信息,包含元数据和配置内容 */ export const getSimulationConfigRealtime = (simulationId) => { return service.get(`/api/simulation/${simulationId}/config/realtime`) } /** * 列出所有模拟 * @param {string} projectId - 可选,按项目ID过滤 */ export const listSimulations = (projectId) => { const params = projectId ? { project_id: projectId } : {} return service.get('/api/simulation/list', { params }) } /** * 启动模拟 * @param {Object} data - { simulation_id, platform?, max_rounds?, enable_graph_memory_update? 
} */ export const startSimulation = (data) => { return requestWithRetry(() => service.post('/api/simulation/start', data), 3, 1000) } /** * 停止模拟 * @param {Object} data - { simulation_id } */ export const stopSimulation = (data) => { return service.post('/api/simulation/stop', data) } /** * 获取模拟运行实时状态 * @param {string} simulationId */ export const getRunStatus = (simulationId) => { return service.get(`/api/simulation/${simulationId}/run-status`) } /** * 获取模拟运行详细状态(包含最近动作) * @param {string} simulationId */ export const getRunStatusDetail = (simulationId) => { return service.get(`/api/simulation/${simulationId}/run-status/detail`) } /** * 获取模拟中的帖子 * @param {string} simulationId * @param {string} platform - 'reddit' | 'twitter' * @param {number} limit - 返回数量 * @param {number} offset - 偏移量 */ export const getSimulationPosts = (simulationId, platform = 'reddit', limit = 50, offset = 0) => { return service.get(`/api/simulation/${simulationId}/posts`, { params: { platform, limit, offset } }) } /** * 获取模拟时间线(按轮次汇总) * @param {string} simulationId * @param {number} startRound - 起始轮次 * @param {number} endRound - 结束轮次 */ export const getSimulationTimeline = (simulationId, startRound = 0, endRound = null) => { const params = { start_round: startRound } if (endRound !== null) { params.end_round = endRound } return service.get(`/api/simulation/${simulationId}/timeline`, { params }) } /** * 获取Agent统计信息 * @param {string} simulationId */ export const getAgentStats = (simulationId) => { return service.get(`/api/simulation/${simulationId}/agent-stats`) } /** * 获取模拟动作历史 * @param {string} simulationId * @param {Object} params - { limit, offset, platform, agent_id, round_num } */ export const getSimulationActions = (simulationId, params = {}) => { return service.get(`/api/simulation/${simulationId}/actions`, { params }) } /** * 关闭模拟环境(优雅退出) * @param {Object} data - { simulation_id, timeout? 
} */ export const closeSimulationEnv = (data) => { return service.post('/api/simulation/close-env', data) } /** * 获取模拟环境状态 * @param {Object} data - { simulation_id } */ export const getEnvStatus = (data) => { return service.post('/api/simulation/env-status', data) } /** * 批量采访 Agent * @param {Object} data - { simulation_id, interviews: [{ agent_id, prompt }] } */ export const interviewAgents = (data) => { return requestWithRetry(() => service.post('/api/simulation/interview/batch', data), 3, 1000) } /** * 获取历史模拟列表(带项目详情) * 用于首页历史项目展示 * @param {number} limit - 返回数量限制 */ export const getSimulationHistory = (limit = 20) => { return service.get('/api/simulation/history', { params: { limit } }) } ================================================ FILE: frontend/src/components/GraphPanel.vue ================================================ ================================================ FILE: frontend/src/components/HistoryDatabase.vue ================================================ ================================================ FILE: frontend/src/components/Step1GraphBuild.vue ================================================ ================================================ FILE: frontend/src/components/Step2EnvSetup.vue ================================================ ================================================ FILE: frontend/src/components/Step3Simulation.vue ================================================ ================================================ FILE: frontend/src/components/Step4Report.vue ================================================ ================================================ FILE: frontend/src/components/Step5Interaction.vue ================================================ ================================================ FILE: frontend/src/main.js ================================================ import { createApp } from 'vue' import App from './App.vue' import router from './router' const app = createApp(App) app.use(router) 
app.mount('#app') ================================================ FILE: frontend/src/router/index.js ================================================ import { createRouter, createWebHistory } from 'vue-router' import Home from '../views/Home.vue' import Process from '../views/MainView.vue' import SimulationView from '../views/SimulationView.vue' import SimulationRunView from '../views/SimulationRunView.vue' import ReportView from '../views/ReportView.vue' import InteractionView from '../views/InteractionView.vue' const routes = [ { path: '/', name: 'Home', component: Home }, { path: '/process/:projectId', name: 'Process', component: Process, props: true }, { path: '/simulation/:simulationId', name: 'Simulation', component: SimulationView, props: true }, { path: '/simulation/:simulationId/start', name: 'SimulationRun', component: SimulationRunView, props: true }, { path: '/report/:reportId', name: 'Report', component: ReportView, props: true }, { path: '/interaction/:reportId', name: 'Interaction', component: InteractionView, props: true } ] const router = createRouter({ history: createWebHistory(), routes }) export default router ================================================ FILE: frontend/src/store/pendingUpload.js ================================================ /** * 临时存储待上传的文件和需求 * 用于首页点击启动引擎后立即跳转,在Process页面再进行API调用 */ import { reactive } from 'vue' const state = reactive({ files: [], simulationRequirement: '', isPending: false }) export function setPendingUpload(files, requirement) { state.files = files state.simulationRequirement = requirement state.isPending = true } export function getPendingUpload() { return { files: state.files, simulationRequirement: state.simulationRequirement, isPending: state.isPending } } export function clearPendingUpload() { state.files = [] state.simulationRequirement = '' state.isPending = false } export default state ================================================ FILE: frontend/src/views/Home.vue 
================================================ ================================================ FILE: frontend/src/views/InteractionView.vue ================================================ ================================================ FILE: frontend/src/views/MainView.vue ================================================ ================================================ FILE: frontend/src/views/Process.vue ================================================ ================================================ FILE: frontend/src/views/ReportView.vue ================================================ ================================================ FILE: frontend/src/views/SimulationRunView.vue ================================================ ================================================ FILE: frontend/src/views/SimulationView.vue ================================================ ================================================ FILE: frontend/vite.config.js ================================================ import { defineConfig } from 'vite' import vue from '@vitejs/plugin-vue' // https://vite.dev/config/ export default defineConfig({ plugins: [vue()], server: { port: 3000, open: true, proxy: { '/api': { target: 'http://localhost:5001', changeOrigin: true, secure: false } } } }) ================================================ FILE: package.json ================================================ { "name": "mirofish", "version": "0.1.0", "description": "MiroFish - 简洁通用的群体智能引擎,预测万物", "scripts": { "setup": "npm install && cd frontend && npm install", "setup:backend": "cd backend && uv sync", "setup:all": "npm run setup && npm run setup:backend", "dev": "concurrently --kill-others -n \"backend,frontend\" -c \"green,cyan\" \"npm run backend\" \"npm run frontend\"", "backend": "cd backend && uv run python run.py", "frontend": "cd frontend && npm run dev", "build": "cd frontend && npm run build" }, "devDependencies": { "concurrently": "^9.1.2" }, "engines": { "node": ">=18.0.0" 
}, "license": "AGPL-3.0" }