## 🦾 Démonstration
### 🛠️ Flux de Travail Standard de l'Assistant
| 🧩 Ingénieur Full-Stack | 🗂️ Gestion des Logs & Planification | 🔎 Recherche Web & Apprentissage |
|---|---|---|
| Développer • Déployer • Mettre à l'échelle | Planifier • Automatiser • Mémoriser | Découvrir • Analyser • Tendances |
### 🐜 Déploiement Innovant à Faible Empreinte
PicoClaw peut être déployé sur pratiquement n'importe quel appareil Linux !
- 9,9$ [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) version E (Ethernet) ou W (WiFi6), pour un Assistant Domotique Minimaliste
- 30~50$ [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou 100$ [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) pour la Maintenance Automatisée de Serveurs
- 50$ [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou 100$ [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) pour la Surveillance Intelligente
================================================
FILE: README.id.md
================================================
## 🦾 Demonstrasi
### 🛠️ Alur Kerja Asisten Standar
| 🧩 Full-Stack Engineer | 🗂️ Pencatatan & Manajemen Perencanaan | 🔎 Pencarian Web & Pembelajaran |
|---|---|---|
| Develop • Deploy • Scale | Jadwal • Otomasi • Memori | Penemuan • Wawasan • Tren |
### 🐜 Deploy Inovatif dengan Footprint Rendah
PicoClaw dapat di-deploy di hampir semua perangkat Linux!
- $9,9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) versi E(Ethernet) atau W(WiFi6), untuk Home Assistant Minimal
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), atau $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) untuk Pemeliharaan Server Otomatis
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) atau $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) untuk Pemantauan Cerdas
================================================
FILE: README.it.md
================================================
## 🦾 Dimostrazione
### 🛠️ Flussi di Lavoro Standard dell'Assistente
| 🧩 Ingegnere Full-Stack | 🗂️ Gestione Log & Pianificazione | 🔎 Ricerca Web & Apprendimento |
|---|---|---|
| Sviluppa • Distribuisci • Scala | Pianifica • Automatizza • Memorizza | Scopri • Analizza • Tendenze |
### 🐜 Deploy Innovativo a Bassa Impronta
PicoClaw può essere distribuito su quasi qualsiasi dispositivo Linux!
- $9,9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) versione E (Ethernet) o W (WiFi6), per un Assistente Domotico Minimale
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), o $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) per la Manutenzione Automatizzata dei Server
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) o $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) per il Monitoraggio Intelligente
================================================
FILE: README.ja.md
================================================
## 🦾 デモンストレーション
### 🛠️ スタンダードアシスタントワークフロー
| 🧩 フルスタックエンジニア | 🗂️ ログ&計画管理 | 🔎 Web 検索&学習 |
|---|---|---|
| 開発 · デプロイ · スケール | スケジュール · 自動化 · メモリ | 発見 · インサイト · トレンド |
### 🐜 革新的な省フットプリントデプロイ
PicoClaw はほぼすべての Linux デバイスにデプロイできます!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) または W(WiFi6) バージョン、最小ホームアシスタントに
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html) または $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) サーバー自動メンテナンスに
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) または $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) スマート監視に
================================================
FILE: README.md
================================================
## 🦾 Demonstration
### 🛠️ Standard Assistant Workflows
| 🧩 Full-Stack Engineer | 🗂️ Logging & Planning Management | 🔎 Web Search & Learning |
|---|---|---|
| Develop • Deploy • Scale | Schedule • Automate • Memory | Discovery • Insights • Trends |
### 🐜 Innovative Low-Footprint Deploy
PicoClaw can be deployed on almost any Linux device!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html), E (Ethernet) or W (WiFi6) version, for a minimal home assistant
- $30 to $50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html), for automated server maintenance
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
================================================
FILE: README.pt-br.md
================================================
## 🦾 Demonstração
### 🛠️ Fluxos de Trabalho Padrão do Assistente
| 🧩 Engenharia Full-Stack | 🗂️ Gerenciamento de Logs & Planejamento | 🔎 Busca Web & Aprendizado |
|---|---|---|
| Desenvolver • Implantar • Escalar | Agendar • Automatizar • Memorizar | Descobrir • Analisar • Tendências |
### 🐜 Implantação Inovadora com Baixo Consumo
O PicoClaw pode ser implantado em praticamente qualquer dispositivo Linux!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) versão E(Ethernet) ou W(WiFi6), para Assistente Doméstico Minimalista
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) para Manutenção Automatizada de Servidores
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) para Monitoramento Inteligente
================================================
FILE: README.vi.md
================================================
## 🦾 Demo
### 🛠️ Quy trình trợ lý tiêu chuẩn
| 🧩 Lập trình Full-Stack | 🗂️ Quản lý Nhật ký & Kế hoạch | 🔎 Tìm kiếm Web & Học hỏi |
|---|---|---|
| Phát triển • Triển khai • Mở rộng | Lên lịch • Tự động hóa • Ghi nhớ | Khám phá • Phân tích • Xu hướng |
### 🐜 Triển khai sáng tạo trên phần cứng tối thiểu
PicoClaw có thể triển khai trên hầu hết mọi thiết bị Linux!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) phiên bản E(Ethernet) hoặc W(WiFi6), dùng làm Trợ lý Gia đình tối giản
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), hoặc $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) dùng cho quản trị Server tự động
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) hoặc $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) dùng cho Giám sát thông minh
================================================
FILE: README.zh.md
================================================
## 🦾 演示
### 🛠️ 标准助手工作流
| 🧩 全栈工程师模式 | 🗂️ 日志与规划管理 | 🔎 网络搜索与学习 |
|---|---|---|
| 开发 • 部署 • 扩展 | 日程 • 自动化 • 记忆 | 发现 • 洞察 • 趋势 |
### 🐜 创新的低占用部署
PicoClaw 几乎可以部署在任何 Linux 设备上!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控
================================================
FILE: ROADMAP.md
================================================
# 🦐 PicoClaw Roadmap
> **Vision**: To build the ultimate lightweight, secure, and fully autonomous AI Agent infrastructure. Automate the mundane, unleash your creativity.
---
## 🚀 1. Core Optimization: Extreme Lightweight
*Our defining characteristic. We fight software bloat to ensure PicoClaw runs smoothly on the smallest embedded devices.*
* [**Memory Footprint Reduction**](https://github.com/sipeed/picoclaw/issues/346)
* **Goal**: Run smoothly on 64MB RAM embedded boards (e.g., low-end RISC-V SBCs) with the core process consuming < 20MB.
* **Context**: RAM is expensive and scarce on edge devices. Memory optimization takes precedence over storage size.
* **Action**: Analyze memory growth between releases, remove redundant dependencies, and optimize data structures.
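One way to track the memory goal above between releases is to sample the Go runtime's own statistics and compare them against the budget. The sketch below is only an illustration: `budgetBytes` and `withinBudget` are hypothetical names, and `MemStats.Sys` counts memory the Go runtime obtained from the OS, which approximates but does not equal the process's resident set size.

```go
package main

import (
	"fmt"
	"runtime"
)

// budgetBytes mirrors the "< 20MB" target from the roadmap.
// The constant name is illustrative, not part of PicoClaw.
const budgetBytes uint64 = 20 << 20

// withinBudget reports whether the sampled byte count is under the limit.
func withinBudget(sysBytes, limit uint64) bool {
	return sysBytes < limit
}

func main() {
	var m runtime.MemStats
	runtime.ReadMemStats(&m) // snapshot of the Go runtime's memory accounting

	fmt.Printf("heap in use: %d KiB, obtained from OS: %d KiB\n",
		m.HeapInuse/1024, m.Sys/1024)
	fmt.Println("within budget:", withinBudget(m.Sys, budgetBytes))
}
```

Logging this figure at startup and after a fixed workload in each release would give a rough growth curve to compare; a precise measurement on a target board would more likely read the resident set size from the OS.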
## 🛡️ 2. Security Hardening: Defense in Depth
*Paying off early technical debt. We invite security experts to help build a "Secure-by-Default" agent.*
* **Input Defense & Permission Control**
* **Prompt Injection Defense**: Harden JSON extraction logic to prevent LLM manipulation.
* **Tool Abuse Prevention**: Strict parameter validation to ensure generated commands stay within safe boundaries.
* **SSRF Protection**: Built-in blocklists for network tools to prevent accessing internal IPs (LAN/Metadata services).
* **Sandboxing & Isolation**
* **Filesystem Sandbox**: Restrict file R/W operations to specific directories only.
* **Context Isolation**: Prevent data leakage between different user sessions or channels.
* **Privacy Redaction**: Auto-redact sensitive info (API Keys, PII) from logs and standard outputs.
* **Authentication & Secrets**
* **Crypto Upgrade**: Adopt modern algorithms like `ChaCha20-Poly1305` for secret storage.
* **OAuth 2.0 Flow**: Deprecate hardcoded API keys in the CLI; move to secure OAuth flows.
## 🔌 3. Connectivity: Protocol-First Architecture
*Connect every model, reach every platform.*
* **Provider**
* [**Architecture Upgrade**](https://github.com/sipeed/picoclaw/issues/283): Refactor from "Vendor-based" to "Protocol-based" classification (e.g., OpenAI-compatible, Ollama-compatible). *(Status: In progress by @Daming, ETA 5 days)*
* **Local Models**: Deep integration with **Ollama**, **vLLM**, **LM Studio**, and **Mistral** (local inference).
* **Online Models**: Continued support for frontier closed-source models.
* **Channel**
* **IM Matrix**: QQ, WeChat (Work), DingTalk, Feishu (Lark), Telegram, Discord, WhatsApp, LINE, Slack, Email, KOOK, Signal, ...
* **Standards**: Support for the **OneBot** protocol.
* [**Attachments**](https://github.com/sipeed/picoclaw/issues/348): Native handling of images, audio, and video attachments.
* **Skill Marketplace**
* [**Discovery skills**](https://github.com/sipeed/picoclaw/issues/287): Implement `find_skill` to automatically discover and install skills from the [GitHub Skills Repo] or other registries.
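A protocol-based classification implies that a model reference carries its protocol, as in strings of the form `openai/gpt-5.4`. A rough sketch of splitting such a reference is shown below; `splitModelRef` and `defaultProtocol` are hypothetical names, and falling back to an OpenAI-compatible protocol for unprefixed names is an assumption made purely for illustration.

```go
package main

import (
	"fmt"
	"strings"
)

// defaultProtocol is an assumed fallback for references without a prefix.
const defaultProtocol = "openai"

// splitModelRef splits "protocol/model" at the first slash.
// Anything after the first slash is kept intact as the model name.
func splitModelRef(ref string) (protocol, model string) {
	if p, m, ok := strings.Cut(ref, "/"); ok && p != "" && m != "" {
		return p, m
	}
	return defaultProtocol, ref
}

func main() {
	refs := []string{"openai/gpt-5.4", "anthropic/claude-sonnet-4.6", "local-model"}
	for _, ref := range refs {
		p, m := splitModelRef(ref)
		fmt.Printf("%-28s protocol=%s model=%s\n", ref, p, m)
	}
}
```

Keying provider lookup on the protocol half means a new vendor that speaks an existing protocol needs configuration only, not a new provider type.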
## 🧠 4. Advanced Capabilities: From Chatbot to Agentic AI
*Beyond conversation—focusing on action and collaboration.*
* **Operations**
* [**MCP Support**](https://github.com/sipeed/picoclaw/issues/290): Native support for the **Model Context Protocol (MCP)**.
* [**Browser Automation**](https://github.com/sipeed/picoclaw/issues/293): Headless browser control via CDP (Chrome DevTools Protocol) or ActionBook.
* [**Mobile Operation**](https://github.com/sipeed/picoclaw/issues/292): Android device control (similar to BotDrop).
* **Multi-Agent Collaboration**
* [**Basic Multi-Agent**](https://github.com/sipeed/picoclaw/issues/294): Implement basic multi-agent collaboration.
* [**Model Routing**](https://github.com/sipeed/picoclaw/issues/295): "Smart Routing" — dispatch simple tasks to small/local models (fast/cheap) and complex tasks to SOTA models (smart).
* [**Swarm Mode**](https://github.com/sipeed/picoclaw/issues/284): Collaboration between multiple PicoClaw instances on the same network.
* [**AIEOS**](https://github.com/sipeed/picoclaw/issues/296): Exploring AI-Native Operating System interaction paradigms.
## 📚 5. Developer Experience (DevEx) & Documentation
*Lowering the barrier to entry so anyone can deploy in minutes.*
* [**QuickGuide (Zero-Config Start)**](https://github.com/sipeed/picoclaw/issues/350)
* Interactive CLI Wizard: If launched without config, automatically detect the environment and guide the user through Token/Network setup step-by-step.
* **Comprehensive Documentation**
* **Platform Guides**: Dedicated guides for Windows, macOS, Linux, and Android.
* **Step-by-Step Tutorials**: "Babysitter-level" guides for configuring Providers and Channels.
* **AI-Assisted Docs**: Using AI to auto-generate API references and code comments (with human verification to prevent hallucinations).
## 🤖 6. Engineering: AI-Powered Open Source
*Born from Vibe Coding, we continue to use AI to accelerate development.*
* **AI-Enhanced CI/CD**
* Integrate AI for automated Code Review, Linting, and PR Labeling.
* **Bot Noise Reduction**: Optimize bot interactions to keep PR timelines clean.
* **Issue Triage**: AI agents to analyze incoming issues and suggest preliminary fixes.
## 🎨 7. Brand & Community
* [**Logo Design**](https://github.com/sipeed/picoclaw/issues/297): We are looking for a **Mantis Shrimp (Stomatopoda)** logo design!
* *Concept*: Needs to reflect "Small but Mighty" and "Lightning Fast Strikes."
---
### 🤝 Call for Contributions
We welcome community contributions to any item on this roadmap! Please comment on the relevant Issue or submit a PR. Let's build the best Edge AI Agent together!
================================================
FILE: cmd/picoclaw/internal/agent/command.go
================================================
package agent
import (
"github.com/spf13/cobra"
)
func NewAgentCommand() *cobra.Command {
var (
message string
sessionKey string
model string
debug bool
)
cmd := &cobra.Command{
Use: "agent",
Short: "Interact with the agent directly",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return agentCmd(message, sessionKey, model, debug)
},
}
cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging")
cmd.Flags().StringVarP(&message, "message", "m", "", "Send a single message (non-interactive mode)")
cmd.Flags().StringVarP(&sessionKey, "session", "s", "cli:default", "Session key")
cmd.Flags().StringVarP(&model, "model", "", "", "Model to use")
return cmd
}
================================================
FILE: cmd/picoclaw/internal/agent/command_test.go
================================================
package agent
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAgentCommand(t *testing.T) {
cmd := NewAgentCommand()
require.NotNil(t, cmd)
assert.Equal(t, "agent", cmd.Use)
assert.Equal(t, "Interact with the agent directly", cmd.Short)
assert.Len(t, cmd.Aliases, 0)
assert.False(t, cmd.HasSubCommands())
assert.Nil(t, cmd.Run)
assert.NotNil(t, cmd.RunE)
assert.Nil(t, cmd.PersistentPreRun)
assert.Nil(t, cmd.PersistentPostRun)
assert.True(t, cmd.HasFlags())
assert.NotNil(t, cmd.Flags().Lookup("debug"))
assert.NotNil(t, cmd.Flags().Lookup("message"))
assert.NotNil(t, cmd.Flags().Lookup("session"))
assert.NotNil(t, cmd.Flags().Lookup("model"))
}
================================================
FILE: cmd/picoclaw/internal/agent/helpers.go
================================================
package agent
import (
"bufio"
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ergochat/readline"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
func agentCmd(message, sessionKey, model string, debug bool) error {
if sessionKey == "" {
sessionKey = "cli:default"
}
if debug {
logger.SetLevel(logger.DEBUG)
fmt.Println("🔍 Debug mode enabled")
}
cfg, err := internal.LoadConfig()
if err != nil {
return fmt.Errorf("error loading config: %w", err)
}
if model != "" {
cfg.Agents.Defaults.ModelName = model
}
provider, modelID, err := providers.CreateProvider(cfg)
if err != nil {
return fmt.Errorf("error creating provider: %w", err)
}
// Use the resolved model ID from provider creation
if modelID != "" {
cfg.Agents.Defaults.ModelName = modelID
}
msgBus := bus.NewMessageBus()
defer msgBus.Close()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
defer agentLoop.Close()
// Log agent startup info (runs in both single-message and interactive modes)
startupInfo := agentLoop.GetStartupInfo()
logger.InfoCF("agent", "Agent initialized",
map[string]any{
"tools_count": startupInfo["tools"].(map[string]any)["count"],
"skills_total": startupInfo["skills"].(map[string]any)["total"],
"skills_available": startupInfo["skills"].(map[string]any)["available"],
})
if message != "" {
ctx := context.Background()
response, err := agentLoop.ProcessDirect(ctx, message, sessionKey)
if err != nil {
return fmt.Errorf("error processing message: %w", err)
}
fmt.Printf("\n%s %s\n", internal.Logo, response)
return nil
}
fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", internal.Logo)
interactiveMode(agentLoop, sessionKey)
return nil
}
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
prompt := fmt.Sprintf("%s You: ", internal.Logo)
rl, err := readline.NewEx(&readline.Config{
Prompt: prompt,
HistoryFile: filepath.Join(os.TempDir(), ".picoclaw_history"),
HistoryLimit: 100,
InterruptPrompt: "^C",
EOFPrompt: "exit",
})
if err != nil {
fmt.Printf("Error initializing readline: %v\n", err)
fmt.Println("Falling back to simple input mode...")
simpleInteractiveMode(agentLoop, sessionKey)
return
}
defer rl.Close()
for {
line, err := rl.Readline()
if err != nil {
if err == readline.ErrInterrupt || err == io.EOF {
fmt.Println("\nGoodbye!")
return
}
fmt.Printf("Error reading input: %v\n", err)
continue
}
input := strings.TrimSpace(line)
if input == "" {
continue
}
if input == "exit" || input == "quit" {
fmt.Println("Goodbye!")
return
}
ctx := context.Background()
response, err := agentLoop.ProcessDirect(ctx, input, sessionKey)
if err != nil {
fmt.Printf("Error: %v\n", err)
continue
}
fmt.Printf("\n%s %s\n\n", internal.Logo, response)
}
}
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
reader := bufio.NewReader(os.Stdin)
for {
fmt.Printf("%s You: ", internal.Logo)
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
fmt.Println("\nGoodbye!")
return
}
fmt.Printf("Error reading input: %v\n", err)
continue
}
input := strings.TrimSpace(line)
if input == "" {
continue
}
if input == "exit" || input == "quit" {
fmt.Println("Goodbye!")
return
}
ctx := context.Background()
response, err := agentLoop.ProcessDirect(ctx, input, sessionKey)
if err != nil {
fmt.Printf("Error: %v\n", err)
continue
}
fmt.Printf("\n%s %s\n\n", internal.Logo, response)
}
}
================================================
FILE: cmd/picoclaw/internal/auth/command.go
================================================
package auth
import "github.com/spf13/cobra"
func NewAuthCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "auth",
Short: "Manage authentication (login, logout, status)",
RunE: func(cmd *cobra.Command, _ []string) error {
return cmd.Help()
},
}
cmd.AddCommand(
newLoginCommand(),
newLogoutCommand(),
newStatusCommand(),
newModelsCommand(),
)
return cmd
}
================================================
FILE: cmd/picoclaw/internal/auth/command_test.go
================================================
package auth
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAuthCommand(t *testing.T) {
cmd := NewAuthCommand()
require.NotNil(t, cmd)
assert.Equal(t, "auth", cmd.Use)
assert.Equal(t, "Manage authentication (login, logout, status)", cmd.Short)
assert.Len(t, cmd.Aliases, 0)
assert.Nil(t, cmd.Run)
assert.NotNil(t, cmd.RunE)
assert.Nil(t, cmd.PersistentPreRun)
assert.Nil(t, cmd.PersistentPostRun)
assert.False(t, cmd.HasFlags())
assert.True(t, cmd.HasSubCommands())
allowedCommands := []string{
"login",
"logout",
"status",
"models",
}
subcommands := cmd.Commands()
assert.Len(t, subcommands, len(allowedCommands))
for _, subcmd := range subcommands {
found := slices.Contains(allowedCommands, subcmd.Name())
assert.True(t, found, "unexpected subcommand %q", subcmd.Name())
assert.Len(t, subcmd.Aliases, 0)
assert.False(t, subcmd.Hidden)
assert.False(t, subcmd.HasSubCommands())
assert.Nil(t, subcmd.Run)
assert.NotNil(t, subcmd.RunE)
assert.Nil(t, subcmd.PersistentPreRun)
assert.Nil(t, subcmd.PersistentPostRun)
}
}
================================================
FILE: cmd/picoclaw/internal/auth/helpers.go
================================================
package auth
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
supportedProvidersMsg = "supported providers: openai, anthropic, google-antigravity"
defaultAnthropicModel = "claude-sonnet-4.6"
)
func authLoginCmd(provider string, useDeviceCode bool, useOauth bool) error {
switch provider {
case "openai":
return authLoginOpenAI(useDeviceCode)
case "anthropic":
return authLoginAnthropic(useOauth)
case "google-antigravity", "antigravity":
return authLoginGoogleAntigravity()
default:
return fmt.Errorf("unsupported provider: %s (%s)", provider, supportedProvidersMsg)
}
}
func authLoginOpenAI(useDeviceCode bool) error {
cfg := auth.OpenAIOAuthConfig()
var cred *auth.AuthCredential
var err error
if useDeviceCode {
cred, err = auth.LoginDeviceCode(cfg)
} else {
cred, err = auth.LoginBrowser(cfg)
}
if err != nil {
return fmt.Errorf("login failed: %w", err)
}
if err = auth.SetCredential("openai", cred); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
appCfg, err := internal.LoadConfig()
if err == nil {
// Update Providers (legacy format)
appCfg.Providers.OpenAI.AuthMethod = "oauth"
// Update or add openai in ModelList
foundOpenAI := false
for i := range appCfg.ModelList {
if isOpenAIModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = "oauth"
foundOpenAI = true
break
}
}
// If no openai in ModelList, add it
if !foundOpenAI {
appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{
ModelName: "gpt-5.4",
Model: "openai/gpt-5.4",
AuthMethod: "oauth",
})
}
// Update default model to use OpenAI
appCfg.Agents.Defaults.ModelName = "gpt-5.4"
if err = config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
return fmt.Errorf("could not update config: %w", err)
}
}
fmt.Println("Login successful!")
if cred.AccountID != "" {
fmt.Printf("Account: %s\n", cred.AccountID)
}
fmt.Println("Default model set to: gpt-5.4")
return nil
}
func authLoginGoogleAntigravity() error {
cfg := auth.GoogleAntigravityOAuthConfig()
cred, err := auth.LoginBrowser(cfg)
if err != nil {
return fmt.Errorf("login failed: %w", err)
}
cred.Provider = "google-antigravity"
// Fetch user email from Google userinfo
email, err := fetchGoogleUserEmail(cred.AccessToken)
if err != nil {
fmt.Printf("Warning: could not fetch email: %v\n", err)
} else {
cred.Email = email
fmt.Printf("Email: %s\n", email)
}
// Fetch Cloud Code Assist project ID
projectID, err := providers.FetchAntigravityProjectID(cred.AccessToken)
if err != nil {
fmt.Printf("Warning: could not fetch project ID: %v\n", err)
fmt.Println("You may need Google Cloud Code Assist enabled on your account.")
} else {
cred.ProjectID = projectID
fmt.Printf("Project: %s\n", projectID)
}
if err = auth.SetCredential("google-antigravity", cred); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
appCfg, err := internal.LoadConfig()
if err == nil {
// Update Providers (legacy format, for backward compatibility)
appCfg.Providers.Antigravity.AuthMethod = "oauth"
// Update or add antigravity in ModelList
foundAntigravity := false
for i := range appCfg.ModelList {
if isAntigravityModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = "oauth"
foundAntigravity = true
break
}
}
// If no antigravity in ModelList, add it
if !foundAntigravity {
appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{
ModelName: "gemini-flash",
Model: "antigravity/gemini-3-flash",
AuthMethod: "oauth",
})
}
// Update default model
appCfg.Agents.Defaults.ModelName = "gemini-flash"
if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Println("\n✓ Google Antigravity login successful!")
fmt.Println("Default model set to: gemini-flash")
fmt.Println("Try it: picoclaw agent -m \"Hello world\"")
return nil
}
func authLoginAnthropic(useOauth bool) error {
if useOauth {
return authLoginAnthropicSetupToken()
}
fmt.Println("Anthropic login method:")
fmt.Println(" 1) Setup token (from `claude setup-token`) (Recommended)")
fmt.Println(" 2) API key (from console.anthropic.com)")
scanner := bufio.NewScanner(os.Stdin)
for {
fmt.Print("Choose [1]: ")
choice := "1"
if scanner.Scan() {
text := strings.TrimSpace(scanner.Text())
if text != "" {
choice = text
}
}
switch choice {
case "1":
return authLoginAnthropicSetupToken()
case "2":
return authLoginPasteToken("anthropic")
default:
fmt.Printf("Invalid choice: %s. Please enter 1 or 2.\n", choice)
}
}
}
func authLoginAnthropicSetupToken() error {
cred, err := auth.LoginSetupToken(os.Stdin)
if err != nil {
return fmt.Errorf("login failed: %w", err)
}
if err = auth.SetCredential("anthropic", cred); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
appCfg, err := internal.LoadConfig()
if err == nil {
appCfg.Providers.Anthropic.AuthMethod = "oauth"
found := false
for i := range appCfg.ModelList {
if isAnthropicModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = "oauth"
found = true
break
}
}
if !found {
appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{
ModelName: defaultAnthropicModel,
Model: "anthropic/" + defaultAnthropicModel,
AuthMethod: "oauth",
})
// Only set default model if user has no default configured yet
if appCfg.Agents.Defaults.GetModelName() == "" {
appCfg.Agents.Defaults.ModelName = defaultAnthropicModel
}
}
if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
return fmt.Errorf("could not update config: %w", err)
}
}
fmt.Println("Setup token saved for Anthropic!")
return nil
}
func fetchGoogleUserEmail(accessToken string) (string, error) {
req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading userinfo response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("userinfo request failed: %s", string(body))
}
var userInfo struct {
Email string `json:"email"`
}
if err := json.Unmarshal(body, &userInfo); err != nil {
return "", err
}
return userInfo.Email, nil
}
func authLoginPasteToken(provider string) error {
cred, err := auth.LoginPasteToken(provider, os.Stdin)
if err != nil {
return fmt.Errorf("login failed: %w", err)
}
if err = auth.SetCredential(provider, cred); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
appCfg, err := internal.LoadConfig()
if err == nil {
switch provider {
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = "token"
// Update ModelList
found := false
for i := range appCfg.ModelList {
if isAnthropicModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = "token"
found = true
break
}
}
if !found {
appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{
ModelName: defaultAnthropicModel,
Model: "anthropic/" + defaultAnthropicModel,
AuthMethod: "token",
})
appCfg.Agents.Defaults.ModelName = defaultAnthropicModel
}
case "openai":
appCfg.Providers.OpenAI.AuthMethod = "token"
// Update ModelList
found := false
for i := range appCfg.ModelList {
if isOpenAIModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = "token"
found = true
break
}
}
if !found {
appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{
ModelName: "gpt-5.4",
Model: "openai/gpt-5.4",
AuthMethod: "token",
})
}
// Update default model
appCfg.Agents.Defaults.ModelName = "gpt-5.4"
}
if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
return fmt.Errorf("could not update config: %w", err)
}
}
fmt.Printf("Token saved for %s!\n", provider)
if appCfg != nil {
fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.GetModelName())
}
return nil
}
func authLogoutCmd(provider string) error {
if provider != "" {
if err := auth.DeleteCredential(provider); err != nil {
return fmt.Errorf("failed to remove credentials: %w", err)
}
appCfg, err := internal.LoadConfig()
if err == nil {
// Clear AuthMethod in ModelList
for i := range appCfg.ModelList {
switch provider {
case "openai":
if isOpenAIModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = ""
}
case "anthropic":
if isAnthropicModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = ""
}
case "google-antigravity", "antigravity":
if isAntigravityModel(appCfg.ModelList[i].Model) {
appCfg.ModelList[i].AuthMethod = ""
}
}
}
// Clear AuthMethod in Providers (legacy)
switch provider {
case "openai":
appCfg.Providers.OpenAI.AuthMethod = ""
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = ""
case "google-antigravity", "antigravity":
appCfg.Providers.Antigravity.AuthMethod = ""
}
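// Report a failed save instead of silently dropping the error
if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}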
}
fmt.Printf("Logged out from %s\n", provider)
return nil
}
if err := auth.DeleteAllCredentials(); err != nil {
return fmt.Errorf("failed to remove credentials: %w", err)
}
appCfg, err := internal.LoadConfig()
if err == nil {
// Clear all AuthMethods in ModelList
for i := range appCfg.ModelList {
appCfg.ModelList[i].AuthMethod = ""
}
// Clear all AuthMethods in Providers (legacy)
appCfg.Providers.OpenAI.AuthMethod = ""
appCfg.Providers.Anthropic.AuthMethod = ""
appCfg.Providers.Antigravity.AuthMethod = ""
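// Report a failed save instead of silently dropping the error
if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}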
}
fmt.Println("Logged out from all providers")
return nil
}
func authStatusCmd() error {
store, err := auth.LoadStore()
if err != nil {
return fmt.Errorf("failed to load auth store: %w", err)
}
if len(store.Credentials) == 0 {
fmt.Println("No authenticated providers.")
fmt.Println("Run: picoclaw auth login --provider
You can close this window.
") resultCh <- callbackResult{code: code} }) listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port)) if err != nil { return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err) } server := &http.Server{Handler: mux} go server.Serve(listener) defer func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() server.Shutdown(ctx) }() fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL) if err := OpenBrowser(authURL); err != nil { fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) } fmt.Printf( "Wait! If you are in a headless environment (like Coolify/VPS) and cannot reach localhost:%d,\n", cfg.Port, ) fmt.Println( "please complete the login in your local browser and then PASTE the final redirect URL (or just the code) here.", ) fmt.Println("Waiting for authentication (browser or manual paste)...") // Start manual input in a goroutine manualCh := make(chan string) go func() { reader := bufio.NewReader(os.Stdin) input, _ := reader.ReadString('\n') manualCh <- strings.TrimSpace(input) }() select { case result := <-resultCh: if result.err != nil { return nil, result.err } return ExchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) case manualInput := <-manualCh: if manualInput == "" { return nil, fmt.Errorf("manual input canceled") } // Extract code from URL if it's a full URL code := manualInput if strings.Contains(manualInput, "?") { u, err := url.Parse(manualInput) if err == nil { code = u.Query().Get("code") } } if code == "" { return nil, fmt.Errorf("could not find authorization code in input") } return ExchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI) case <-time.After(5 * time.Minute): return nil, fmt.Errorf("authentication timed out after 5 minutes") } } type callbackResult struct { code string err error } type deviceCodeResponse struct { DeviceAuthID string UserCode string Interval int } // 
DeviceCodeInfo holds the device code information returned by the OAuth provider.
type DeviceCodeInfo struct {
	DeviceAuthID string `json:"device_auth_id"`
	UserCode     string `json:"user_code"`
	VerifyURL    string `json:"verify_url"`
	Interval     int    `json:"interval"`
}

// RequestDeviceCode requests a device code from the OAuth provider.
// Returns the info needed for the user to authenticate in a browser.
func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
	reqBody, _ := json.Marshal(map[string]string{
		"client_id": cfg.ClientID,
	})

	resp, err := http.Post(
		cfg.Issuer+"/api/accounts/deviceauth/usercode",
		"application/json",
		strings.NewReader(string(reqBody)),
	)
	if err != nil {
		return nil, fmt.Errorf("requesting device code: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading device code response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("device code request failed: %s", string(body))
	}

	deviceResp, err := parseDeviceCodeResponse(body)
	if err != nil {
		return nil, fmt.Errorf("parsing device code response: %w", err)
	}

	// Fall back to a 5-second polling interval when the provider omits one.
	if deviceResp.Interval < 1 {
		deviceResp.Interval = 5
	}

	return &DeviceCodeInfo{
		DeviceAuthID: deviceResp.DeviceAuthID,
		UserCode:     deviceResp.UserCode,
		VerifyURL:    cfg.Issuer + "/codex/device",
		Interval:     deviceResp.Interval,
	}, nil
}

// PollDeviceCodeOnce makes a single poll attempt to check if the user has authenticated.
// It returns (credential, nil) on success. Any non-200 response, including the one sent while
// authorization is still pending, is returned as (nil, err).
func PollDeviceCodeOnce(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) { return pollDeviceCode(cfg, deviceAuthID, userCode) } func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) { var raw struct { DeviceAuthID string `json:"device_auth_id"` UserCode string `json:"user_code"` Interval json.RawMessage `json:"interval"` } if err := json.Unmarshal(body, &raw); err != nil { return deviceCodeResponse{}, err } interval, err := parseFlexibleInt(raw.Interval) if err != nil { return deviceCodeResponse{}, err } return deviceCodeResponse{ DeviceAuthID: raw.DeviceAuthID, UserCode: raw.UserCode, Interval: interval, }, nil } func parseFlexibleInt(raw json.RawMessage) (int, error) { if len(raw) == 0 || string(raw) == "null" { return 0, nil } var interval int if err := json.Unmarshal(raw, &interval); err == nil { return interval, nil } var intervalStr string if err := json.Unmarshal(raw, &intervalStr); err == nil { intervalStr = strings.TrimSpace(intervalStr) if intervalStr == "" { return 0, nil } return strconv.Atoi(intervalStr) } return 0, fmt.Errorf("invalid integer value: %s", string(raw)) } func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { reqBody, _ := json.Marshal(map[string]string{ "client_id": cfg.ClientID, }) resp, err := http.Post( cfg.Issuer+"/api/accounts/deviceauth/usercode", "application/json", strings.NewReader(string(reqBody)), ) if err != nil { return nil, fmt.Errorf("requesting device code: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("reading device code response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("device code request failed: %s", string(body)) } deviceResp, err := parseDeviceCodeResponse(body) if err != nil { return nil, fmt.Errorf("parsing device code response: %w", err) } if deviceResp.Interval < 1 { deviceResp.Interval = 5 } fmt.Printf( "\nTo authenticate, open this URL in your 
browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", cfg.Issuer, deviceResp.UserCode, ) deadline := time.After(15 * time.Minute) ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second) defer ticker.Stop() for { select { case <-deadline: return nil, fmt.Errorf("device code authentication timed out after 15 minutes") case <-ticker.C: cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode) if err != nil { continue } if cred != nil { return cred, nil } } } } func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) { reqBody, _ := json.Marshal(map[string]string{ "device_auth_id": deviceAuthID, "user_code": userCode, }) resp, err := http.Post( cfg.Issuer+"/api/accounts/deviceauth/token", "application/json", strings.NewReader(string(reqBody)), ) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("pending") } body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("reading device token response: %w", err) } var tokenResp struct { AuthorizationCode string `json:"authorization_code"` CodeChallenge string `json:"code_challenge"` CodeVerifier string `json:"code_verifier"` } if err := json.Unmarshal(body, &tokenResp); err != nil { return nil, err } redirectURI := cfg.Issuer + "/deviceauth/callback" return ExchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI) } func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) { if cred.RefreshToken == "" { return nil, fmt.Errorf("no refresh token available") } data := url.Values{ "client_id": {cfg.ClientID}, "grant_type": {"refresh_token"}, "refresh_token": {cred.RefreshToken}, "scope": {"openid profile email"}, } if cfg.ClientSecret != "" { data.Set("client_secret", cfg.ClientSecret) } tokenURL := cfg.Issuer + "/oauth/token" if cfg.TokenURL != "" { 
tokenURL = cfg.TokenURL } resp, err := http.PostForm(tokenURL, data) if err != nil { return nil, fmt.Errorf("refreshing token: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("reading token refresh response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token refresh failed: %s", string(body)) } refreshed, err := parseTokenResponse(body, cred.Provider) if err != nil { return nil, err } if refreshed.RefreshToken == "" { refreshed.RefreshToken = cred.RefreshToken } if refreshed.AccountID == "" { refreshed.AccountID = cred.AccountID } if cred.Email != "" && refreshed.Email == "" { refreshed.Email = cred.Email } if cred.ProjectID != "" && refreshed.ProjectID == "" { refreshed.ProjectID = cred.ProjectID } return refreshed, nil } func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { return buildAuthorizeURL(cfg, pkce, state, redirectURI) } func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { params := url.Values{ "response_type": {"code"}, "client_id": {cfg.ClientID}, "redirect_uri": {redirectURI}, "scope": {cfg.Scopes}, "code_challenge": {pkce.CodeChallenge}, "code_challenge_method": {"S256"}, "state": {state}, } isGoogle := strings.Contains(strings.ToLower(cfg.Issuer), "accounts.google.com") if isGoogle { // Google OAuth requires these for refresh token support params.Set("access_type", "offline") params.Set("prompt", "consent") } else { // OpenAI-specific parameters params.Set("id_token_add_organizations", "true") params.Set("codex_cli_simplified_flow", "true") if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") { params.Set("originator", "picoclaw") } if cfg.Originator != "" { params.Set("originator", cfg.Originator) } } // Google uses /auth path, OpenAI uses /oauth/authorize if isGoogle { return cfg.Issuer + "/auth?" 
+ params.Encode() } return cfg.Issuer + "/oauth/authorize?" + params.Encode() } // ExchangeCodeForTokens exchanges an authorization code for tokens. func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) { data := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, "redirect_uri": {redirectURI}, "client_id": {cfg.ClientID}, "code_verifier": {codeVerifier}, } if cfg.ClientSecret != "" { data.Set("client_secret", cfg.ClientSecret) } tokenURL := cfg.Issuer + "/oauth/token" if cfg.TokenURL != "" { tokenURL = cfg.TokenURL } // Determine provider name from config provider := "openai" if cfg.TokenURL != "" && strings.Contains(cfg.TokenURL, "googleapis.com") { provider = "google-antigravity" } resp, err := http.PostForm(tokenURL, data) if err != nil { return nil, fmt.Errorf("exchanging code for tokens: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("reading token exchange response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token exchange failed: %s", string(body)) } return parseTokenResponse(body, provider) } func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { var tokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` IDToken string `json:"id_token"` } if err := json.Unmarshal(body, &tokenResp); err != nil { return nil, fmt.Errorf("parsing token response: %w", err) } if tokenResp.AccessToken == "" { return nil, fmt.Errorf("no access token in response") } var expiresAt time.Time if tokenResp.ExpiresIn > 0 { expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) } cred := &AuthCredential{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, ExpiresAt: expiresAt, Provider: provider, AuthMethod: "oauth", } if accountID := extractAccountID(tokenResp.IDToken); 
accountID != "" {
		// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims,
		// so the ID token is checked first and the access token is the fallback.
		cred.AccountID = accountID
	} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
		cred.AccountID = accountID
	}

	return cred, nil
}

func extractAccountID(token string) string {
	claims, err := parseJWTClaims(token)
	if err != nil {
		return ""
	}

	if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
		return accountID
	}

	if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
		return accountID
	}

	if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]any); ok {
		if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
			return accountID
		}
	}

	// Last resort: use the ID of the first organization that has one.
	if orgs, ok := claims["organizations"].([]any); ok {
		for _, org := range orgs {
			if orgMap, ok := org.(map[string]any); ok {
				if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
					return accountID
				}
			}
		}
	}

	return ""
}

func parseJWTClaims(token string) (map[string]any, error) {
	parts := strings.Split(token, ".")
	if len(parts) < 2 {
		return nil, fmt.Errorf("token is not a JWT")
	}

	// JWT segments are unpadded base64url; restore the padding before decoding.
	payload := parts[1]
	switch len(payload) % 4 {
	case 2:
		payload += "=="
	case 3:
		payload += "="
	}

	decoded, err := base64URLDecode(payload)
	if err != nil {
		return nil, err
	}

	var claims map[string]any
	if err := json.Unmarshal(decoded, &claims); err != nil {
		return nil, err
	}

	return claims, nil
}

func base64URLDecode(s string) ([]byte, error) {
	// Map the URL-safe alphabet back to the standard one, then decode.
	s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
	return base64.StdEncoding.DecodeString(s)
}

// OpenBrowser opens the given URL in the user's default browser.
func OpenBrowser(url string) error { switch runtime.GOOS { case "darwin": return exec.Command("open", url).Start() case "linux": return exec.Command("xdg-open", url).Start() case "windows": return exec.Command("cmd", "/c", "start", url).Start() default: return fmt.Errorf("unsupported platform: %s", runtime.GOOS) } } ================================================ FILE: pkg/auth/oauth_test.go ================================================ package auth import ( "encoding/base64" "encoding/json" "net/http" "net/http/httptest" "net/url" "strings" "testing" ) func makeJWTForClaims(t *testing.T, claims map[string]any) string { t.Helper() header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) payloadJSON, err := json.Marshal(claims) if err != nil { t.Fatalf("marshal claims: %v", err) } payload := base64.RawURLEncoding.EncodeToString(payloadJSON) return header + "." + payload + ".sig" } func TestBuildAuthorizeURL(t *testing.T) { cfg := OAuthProviderConfig{ Issuer: "https://auth.example.com", ClientID: "test-client-id", Scopes: "openid profile", Originator: "codex_cli_rs", Port: 1455, } pkce := PKCECodes{ CodeVerifier: "test-verifier", CodeChallenge: "test-challenge", } u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") if !strings.HasPrefix(u, "https://auth.example.com/oauth/authorize?") { t.Errorf("URL does not start with expected prefix: %s", u) } if !strings.Contains(u, "client_id=test-client-id") { t.Error("URL missing client_id") } if !strings.Contains(u, "code_challenge=test-challenge") { t.Error("URL missing code_challenge") } if !strings.Contains(u, "code_challenge_method=S256") { t.Error("URL missing code_challenge_method") } if !strings.Contains(u, "state=test-state") { t.Error("URL missing state") } if !strings.Contains(u, "response_type=code") { t.Error("URL missing response_type") } if !strings.Contains(u, "id_token_add_organizations=true") { t.Error("URL missing id_token_add_organizations") 
} if !strings.Contains(u, "codex_cli_simplified_flow=true") { t.Error("URL missing codex_cli_simplified_flow") } if !strings.Contains(u, "originator=codex_cli_rs") { t.Error("URL missing originator") } } func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) { cfg := OpenAIOAuthConfig() pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"} u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") parsed, err := url.Parse(u) if err != nil { t.Fatalf("url.Parse() error: %v", err) } q := parsed.Query() if q.Get("id_token_add_organizations") != "true" { t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations")) } if q.Get("codex_cli_simplified_flow") != "true" { t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow")) } if q.Get("originator") != "codex_cli_rs" { t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator")) } } func TestParseTokenResponse(t *testing.T) { resp := map[string]any{ "access_token": "test-access-token", "refresh_token": "test-refresh-token", "expires_in": 3600, "id_token": "test-id-token", } body, _ := json.Marshal(resp) cred, err := parseTokenResponse(body, "openai") if err != nil { t.Fatalf("parseTokenResponse() error: %v", err) } if cred.AccessToken != "test-access-token" { t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token") } if cred.RefreshToken != "test-refresh-token" { t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token") } if cred.Provider != "openai" { t.Errorf("Provider = %q, want %q", cred.Provider, "openai") } if cred.AuthMethod != "oauth" { t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") } if cred.ExpiresAt.IsZero() { t.Error("ExpiresAt should not be zero") } } func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) { idToken := makeJWTForClaims(t, map[string]any{"chatgpt_account_id": "acc-id-from-id-token"}) resp := 
map[string]any{ "access_token": "opaque-access-token", "refresh_token": "test-refresh-token", "expires_in": 3600, "id_token": idToken, } body, _ := json.Marshal(resp) cred, err := parseTokenResponse(body, "openai") if err != nil { t.Fatalf("parseTokenResponse() error: %v", err) } if cred.AccountID != "acc-id-from-id-token" { t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token") } } func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) { token := makeJWTForClaims(t, map[string]any{ "organizations": []any{ map[string]any{"id": "org_from_orgs"}, }, }) if got := extractAccountID(token); got != "org_from_orgs" { t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs") } } func TestParseTokenResponseNoAccessToken(t *testing.T) { body := []byte(`{"refresh_token": "test"}`) _, err := parseTokenResponse(body, "openai") if err == nil { t.Error("expected error for missing access_token") } } func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) { idToken := makeJWTWithAccountID("acc-from-id") resp := map[string]any{ "access_token": "not-a-jwt", "refresh_token": "test-refresh-token", "expires_in": 3600, "id_token": idToken, } body, _ := json.Marshal(resp) cred, err := parseTokenResponse(body, "openai") if err != nil { t.Fatalf("parseTokenResponse() error: %v", err) } if cred.AccountID != "acc-from-id" { t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-from-id") } } func makeJWTWithAccountID(accountID string) string { header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) payload := base64.RawURLEncoding.EncodeToString( []byte(`{"https://api.openai.com/auth":{"chatgpt_account_id":"` + accountID + `"}}`), ) return header + "." 
+ payload + ".sig" } func TestExchangeCodeForTokens(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/oauth/token" { http.Error(w, "not found", http.StatusNotFound) return } if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } r.ParseForm() if r.FormValue("grant_type") != "authorization_code" { http.Error(w, "invalid grant_type", http.StatusBadRequest) return } resp := map[string]any{ "access_token": "mock-access-token", "refresh_token": "mock-refresh-token", "expires_in": 3600, } json.NewEncoder(w).Encode(resp) })) defer server.Close() cfg := OAuthProviderConfig{ Issuer: server.URL, ClientID: "test-client", Scopes: "openid", Port: 1455, } cred, err := ExchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback") if err != nil { t.Fatalf("ExchangeCodeForTokens() error: %v", err) } if cred.AccessToken != "mock-access-token" { t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token") } } func TestRefreshAccessToken(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/oauth/token" { http.Error(w, "not found", http.StatusNotFound) return } r.ParseForm() if r.FormValue("grant_type") != "refresh_token" { http.Error(w, "invalid grant_type", http.StatusBadRequest) return } resp := map[string]any{ "access_token": "refreshed-access-token", "refresh_token": "refreshed-refresh-token", "expires_in": 3600, } json.NewEncoder(w).Encode(resp) })) defer server.Close() cfg := OAuthProviderConfig{ Issuer: server.URL, ClientID: "test-client", } cred := &AuthCredential{ AccessToken: "old-token", RefreshToken: "old-refresh-token", Provider: "openai", AuthMethod: "oauth", } refreshed, err := RefreshAccessToken(cred, cfg) if err != nil { t.Fatalf("RefreshAccessToken() error: %v", err) } if refreshed.AccessToken != 
"refreshed-access-token" { t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token") } if refreshed.RefreshToken != "refreshed-refresh-token" { t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token") } } func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { cfg := OpenAIOAuthConfig() cred := &AuthCredential{ AccessToken: "old-token", Provider: "openai", AuthMethod: "oauth", } _, err := RefreshAccessToken(cred, cfg) if err == nil { t.Error("expected error for missing refresh token") } } func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { resp := map[string]any{ "access_token": "new-access-token-only", "expires_in": 3600, } json.NewEncoder(w).Encode(resp) })) defer server.Close() cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"} cred := &AuthCredential{ AccessToken: "old-access", RefreshToken: "existing-refresh", AccountID: "acc_existing", Provider: "openai", AuthMethod: "oauth", } refreshed, err := RefreshAccessToken(cred, cfg) if err != nil { t.Fatalf("RefreshAccessToken() error: %v", err) } if refreshed.RefreshToken != "existing-refresh" { t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh") } if refreshed.AccountID != "acc_existing" { t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing") } } func TestOpenAIOAuthConfig(t *testing.T) { cfg := OpenAIOAuthConfig() if cfg.Issuer != "https://auth.openai.com" { t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com") } if cfg.ClientID == "" { t.Error("ClientID is empty") } if cfg.Port != 1455 { t.Errorf("Port = %d, want 1455", cfg.Port) } } func TestParseDeviceCodeResponseIntervalAsNumber(t *testing.T) { body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":5}`) resp, err := parseDeviceCodeResponse(body) if err != nil { 
t.Fatalf("parseDeviceCodeResponse() error: %v", err) } if resp.DeviceAuthID != "abc" { t.Errorf("DeviceAuthID = %q, want %q", resp.DeviceAuthID, "abc") } if resp.UserCode != "DEF-1234" { t.Errorf("UserCode = %q, want %q", resp.UserCode, "DEF-1234") } if resp.Interval != 5 { t.Errorf("Interval = %d, want %d", resp.Interval, 5) } } func TestParseDeviceCodeResponseIntervalAsString(t *testing.T) { body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"5"}`) resp, err := parseDeviceCodeResponse(body) if err != nil { t.Fatalf("parseDeviceCodeResponse() error: %v", err) } if resp.Interval != 5 { t.Errorf("Interval = %d, want %d", resp.Interval, 5) } } func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) { body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"abc"}`) if _, err := parseDeviceCodeResponse(body); err == nil { t.Fatal("expected error for invalid interval") } } ================================================ FILE: pkg/auth/pkce.go ================================================ package auth import ( "crypto/rand" "crypto/sha256" "encoding/base64" ) type PKCECodes struct { CodeVerifier string CodeChallenge string } func GeneratePKCE() (PKCECodes, error) { buf := make([]byte, 64) if _, err := rand.Read(buf); err != nil { return PKCECodes{}, err } verifier := base64.RawURLEncoding.EncodeToString(buf) hash := sha256.Sum256([]byte(verifier)) challenge := base64.RawURLEncoding.EncodeToString(hash[:]) return PKCECodes{ CodeVerifier: verifier, CodeChallenge: challenge, }, nil } ================================================ FILE: pkg/auth/pkce_test.go ================================================ package auth import ( "crypto/sha256" "encoding/base64" "testing" ) func TestGeneratePKCE(t *testing.T) { codes, err := GeneratePKCE() if err != nil { t.Fatalf("GeneratePKCE() error: %v", err) } if codes.CodeVerifier == "" { t.Fatal("CodeVerifier is empty") } if codes.CodeChallenge == "" { t.Fatal("CodeChallenge is 
empty") } verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier) if err != nil { t.Fatalf("CodeVerifier is not valid base64url: %v", err) } if len(verifierBytes) != 64 { t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes)) } hash := sha256.Sum256([]byte(codes.CodeVerifier)) expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) if codes.CodeChallenge != expectedChallenge { t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge) } } func TestGeneratePKCEUniqueness(t *testing.T) { codes1, err := GeneratePKCE() if err != nil { t.Fatalf("GeneratePKCE() error: %v", err) } codes2, err := GeneratePKCE() if err != nil { t.Fatalf("GeneratePKCE() error: %v", err) } if codes1.CodeVerifier == codes2.CodeVerifier { t.Error("two GeneratePKCE() calls produced identical verifiers") } } ================================================ FILE: pkg/auth/store.go ================================================ package auth import ( "encoding/json" "os" "path/filepath" "time" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/fileutil" ) type AuthCredential struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` AccountID string `json:"account_id,omitempty"` ExpiresAt time.Time `json:"expires_at,omitempty"` Provider string `json:"provider"` AuthMethod string `json:"auth_method"` Email string `json:"email,omitempty"` ProjectID string `json:"project_id,omitempty"` } type AuthStore struct { Credentials map[string]*AuthCredential `json:"credentials"` } func (c *AuthCredential) IsExpired() bool { if c.ExpiresAt.IsZero() { return false } return time.Now().After(c.ExpiresAt) } func (c *AuthCredential) NeedsRefresh() bool { if c.ExpiresAt.IsZero() { return false } return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) } func authFilePath() string { if home := os.Getenv(config.EnvHome); home != "" { return filepath.Join(home, 
"auth.json") } home, _ := os.UserHomeDir() return filepath.Join(home, ".picoclaw", "auth.json") } func LoadStore() (*AuthStore, error) { path := authFilePath() data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil } return nil, err } var store AuthStore if err := json.Unmarshal(data, &store); err != nil { return nil, err } if store.Credentials == nil { store.Credentials = make(map[string]*AuthCredential) } return &store, nil } func SaveStore(store *AuthStore) error { path := authFilePath() data, err := json.MarshalIndent(store, "", " ") if err != nil { return err } // Use unified atomic write utility with explicit sync for flash storage reliability. return fileutil.WriteFileAtomic(path, data, 0o600) } func GetCredential(provider string) (*AuthCredential, error) { store, err := LoadStore() if err != nil { return nil, err } cred, ok := store.Credentials[provider] if !ok { return nil, nil } return cred, nil } func SetCredential(provider string, cred *AuthCredential) error { store, err := LoadStore() if err != nil { return err } store.Credentials[provider] = cred return SaveStore(store) } func DeleteCredential(provider string) error { store, err := LoadStore() if err != nil { return err } delete(store.Credentials, provider) return SaveStore(store) } func DeleteAllCredentials() error { path := authFilePath() if err := os.Remove(path); err != nil && !os.IsNotExist(err) { return err } return nil } ================================================ FILE: pkg/auth/store_test.go ================================================ package auth import ( "os" "path/filepath" "testing" "time" ) func TestAuthCredentialIsExpired(t *testing.T) { tests := []struct { name string expiresAt time.Time want bool }{ {"zero time", time.Time{}, false}, {"future", time.Now().Add(time.Hour), false}, {"past", time.Now().Add(-time.Hour), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := 
&AuthCredential{ExpiresAt: tt.expiresAt}
			if got := c.IsExpired(); got != tt.want {
				t.Errorf("IsExpired() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestAuthCredentialNeedsRefresh(t *testing.T) {
	tests := []struct {
		name      string
		expiresAt time.Time
		want      bool
	}{
		{"zero time", time.Time{}, false},
		{"far future", time.Now().Add(time.Hour), false},
		{"within 5 min", time.Now().Add(3 * time.Minute), true},
		{"already expired", time.Now().Add(-time.Minute), true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &AuthCredential{ExpiresAt: tt.expiresAt}
			if got := c.NeedsRefresh(); got != tt.want {
				t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestStoreRoundtrip(t *testing.T) {
	tmpDir := t.TempDir()
	// t.Setenv restores the previous HOME value when the test ends.
	t.Setenv("HOME", tmpDir)

	cred := &AuthCredential{
		AccessToken:  "test-access-token",
		RefreshToken: "test-refresh-token",
		AccountID:    "acct-123",
		ExpiresAt:    time.Now().Add(time.Hour).Truncate(time.Second),
		Provider:     "openai",
		AuthMethod:   "oauth",
	}

	if err := SetCredential("openai", cred); err != nil {
		t.Fatalf("SetCredential() error: %v", err)
	}

	loaded, err := GetCredential("openai")
	if err != nil {
		t.Fatalf("GetCredential() error: %v", err)
	}
	if loaded == nil {
		t.Fatal("GetCredential() returned nil")
	}
	if loaded.AccessToken != cred.AccessToken {
		t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
	}
	if loaded.RefreshToken != cred.RefreshToken {
		t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
	}
	if loaded.Provider != cred.Provider {
		t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
	}
}

func TestStoreFilePermissions(t *testing.T) {
	tmpDir := t.TempDir()
	// t.Setenv restores the previous HOME value when the test ends.
	t.Setenv("HOME", tmpDir)

	cred := &AuthCredential{
		AccessToken: "secret-token",
		Provider:    "openai",
		AuthMethod:  "oauth",
	}

	if err := SetCredential("openai", cred); err != nil {
		t.Fatalf("SetCredential() error: %v", err)
	}

	path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
	info, err := os.Stat(path)
	if err != nil {
		t.Fatalf("Stat() error: %v", err)
	}

	// The credential file must be readable and writable by the owner only.
	perm := info.Mode().Perm()
	if perm != 0o600 {
		t.Errorf("file permissions = %o, want 0600", perm)
	}
}

func TestStoreMultiProvider(t *testing.T) {
	tmpDir := t.TempDir()
	// t.Setenv restores the previous HOME value when the test ends.
	t.Setenv("HOME", tmpDir)

	openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
	anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}

	if err := SetCredential("openai", openaiCred); err != nil {
		t.Fatalf("SetCredential(openai) error: %v", err)
	}
	if err := SetCredential("anthropic", anthropicCred); err != nil {
		t.Fatalf("SetCredential(anthropic) error: %v", err)
	}

	loaded, err := GetCredential("openai")
	if err != nil {
		t.Fatalf("GetCredential(openai) error: %v", err)
	}
	if loaded.AccessToken != "openai-token" {
		t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
	}

	loaded, err = GetCredential("anthropic")
	if err != nil {
		t.Fatalf("GetCredential(anthropic) error: %v", err)
	}
	if loaded.AccessToken != "anthropic-token" {
		t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
	}
}

func TestDeleteCredential(t *testing.T) {
	tmpDir := t.TempDir()
	// t.Setenv restores the previous HOME value when the test ends.
	t.Setenv("HOME", tmpDir)

	cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
	if err := SetCredential("openai", cred); err != nil {
		t.Fatalf("SetCredential() error: %v", err)
	}

	if err := DeleteCredential("openai"); err != nil {
		t.Fatalf("DeleteCredential() error: %v", err)
	}

	loaded, err := GetCredential("openai")
	if err != nil {
		t.Fatalf("GetCredential() error: %v", err)
	}
	if loaded != nil {
		t.Error("expected nil after delete")
	}
}

func TestLoadStoreEmpty(t *testing.T) {
	tmpDir :=
t.TempDir()
	// t.Setenv restores the previous HOME value when the test ends.
	t.Setenv("HOME", tmpDir)

	store, err := LoadStore()
	if err != nil {
		t.Fatalf("LoadStore() error: %v", err)
	}
	if store == nil {
		t.Fatal("LoadStore() returned nil")
	}
	if len(store.Credentials) != 0 {
		t.Errorf("expected empty credentials, got %d", len(store.Credentials))
	}
}

================================================
FILE: pkg/auth/token.go
================================================
package auth

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

// LoginPasteToken reads a single line from r and stores it as a token credential for provider.
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
	fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
	fmt.Print("> ")

	scanner := bufio.NewScanner(r)
	if !scanner.Scan() {
		if err := scanner.Err(); err != nil {
			return nil, fmt.Errorf("reading token: %w", err)
		}
		return nil, fmt.Errorf("no input received")
	}

	token := strings.TrimSpace(scanner.Text())
	if token == "" {
		return nil, fmt.Errorf("token cannot be empty")
	}

	return &AuthCredential{
		AccessToken: token,
		Provider:    provider,
		AuthMethod:  "token",
	}, nil
}

// LoginSetupToken reads a setup token from r and validates its prefix and minimum length.
func LoginSetupToken(r io.Reader) (*AuthCredential, error) {
	fmt.Println("Paste your setup token from `claude setup-token`:")
	fmt.Print("> ")

	scanner := bufio.NewScanner(r)
	if !scanner.Scan() {
		if err := scanner.Err(); err != nil {
			return nil, fmt.Errorf("reading token: %w", err)
		}
		return nil, fmt.Errorf("no input received")
	}

	token := strings.TrimSpace(scanner.Text())

	if !strings.HasPrefix(token, "sk-ant-oat01-") {
		return nil, fmt.Errorf("invalid setup token: expected prefix sk-ant-oat01-")
	}

	if len(token) < 80 {
		return nil, fmt.Errorf("invalid setup token: too short (expected at least 80 characters)")
	}

	return &AuthCredential{
		AccessToken: token,
		Provider:    "anthropic",
		AuthMethod:  "oauth",
	}, nil
}

func providerDisplayName(provider string) string {
	switch provider {
	case "anthropic":
		return "console.anthropic.com"
	case "openai":
		return "platform.openai.com"
	default:
		return provider
	}
} ================================================ FILE: pkg/auth/token_test.go ================================================ package auth import ( "strings" "testing" ) func TestLoginSetupToken(t *testing.T) { // A valid token: correct prefix + at least 80 chars validToken := "sk-ant-oat01-" + strings.Repeat("a", 80) tests := []struct { name string input string wantErr string }{ {"valid token", validToken, ""}, {"empty input", "", "expected prefix sk-ant-oat01-"}, {"wrong prefix", "sk-ant-api-" + strings.Repeat("a", 80), "expected prefix sk-ant-oat01-"}, {"too short", "sk-ant-oat01-short", "too short"}, {"whitespace only", " ", "expected prefix sk-ant-oat01-"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := strings.NewReader(tt.input + "\n") cred, err := LoginSetupToken(r) if tt.wantErr != "" { if err == nil { t.Fatalf("expected error containing %q, got nil", tt.wantErr) } if !strings.Contains(err.Error(), tt.wantErr) { t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) } return } if err != nil { t.Fatalf("unexpected error: %v", err) } if cred.AccessToken != validToken { t.Errorf("AccessToken = %q, want %q", cred.AccessToken, validToken) } if cred.Provider != "anthropic" { t.Errorf("Provider = %q, want %q", cred.Provider, "anthropic") } if cred.AuthMethod != "oauth" { t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") } }) } } func TestLoginSetupToken_EmptyReader(t *testing.T) { r := strings.NewReader("") _, err := LoginSetupToken(r) if err == nil { t.Fatal("expected error for empty reader, got nil") } } ================================================ FILE: pkg/bus/bus.go ================================================ package bus import ( "context" "errors" "sync" "sync/atomic" "github.com/sipeed/picoclaw/pkg/logger" ) // ErrBusClosed is returned when publishing to a closed MessageBus. 
var ErrBusClosed = errors.New("message bus closed")

const defaultBusBufferSize = 64

type MessageBus struct {
	inbound       chan InboundMessage
	outbound      chan OutboundMessage
	outboundMedia chan OutboundMediaMessage

	closeOnce sync.Once
	done      chan struct{}
	closed    atomic.Bool
	wg        sync.WaitGroup
}

func NewMessageBus() *MessageBus {
	return &MessageBus{
		inbound:       make(chan InboundMessage, defaultBusBufferSize),
		outbound:      make(chan OutboundMessage, defaultBusBufferSize),
		outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
		done:          make(chan struct{}),
	}
}

func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error {
	// Check whether the bus is closed before calling wg.Add, to avoid an
	// unnecessary wg.Add and a potential deadlock.
	if mb.closed.Load() {
		return ErrBusClosed
	}

	// Check again before sending, to avoid sending on a closed channel.
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-mb.done:
		return ErrBusClosed
	default:
	}

	mb.wg.Add(1)
	defer mb.wg.Done()

	select {
	case ch <- msg:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-mb.done:
		return ErrBusClosed
	}
}

func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
	return publish(ctx, mb, mb.inbound, msg)
}

func (mb *MessageBus) InboundChan() <-chan InboundMessage {
	return mb.inbound
}

func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
	return publish(ctx, mb, mb.outbound, msg)
}

func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
	return mb.outbound
}

func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
	return publish(ctx, mb, mb.outboundMedia, msg)
}

func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
	return mb.outboundMedia
}

func (mb *MessageBus) Close() {
	mb.closeOnce.Do(func() {
		// notify all blocked publishers to exit
		close(mb.done)

		// because every publisher will check mb.closed before acquiring wg
		// so we can be sure that new publishers will not be added
new messages after this point mb.closed.Store(true) // wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited mb.wg.Wait() // close channels safely close(mb.inbound) close(mb.outbound) close(mb.outboundMedia) // clean up any remaining messages in channels drained := 0 for range mb.inbound { drained++ } for range mb.outbound { drained++ } for range mb.outboundMedia { drained++ } if drained > 0 { logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ "count": drained, }) } }) } ================================================ FILE: pkg/bus/bus_test.go ================================================ package bus import ( "context" "sync" "testing" "time" ) func TestPublishConsume(t *testing.T) { mb := NewMessageBus() defer mb.Close() ctx := context.Background() msg := InboundMessage{ Channel: "test", SenderID: "user1", ChatID: "chat1", Content: "hello", } if err := mb.PublishInbound(ctx, msg); err != nil { t.Fatalf("PublishInbound failed: %v", err) } got, ok := <-mb.InboundChan() if !ok { t.Fatal("ConsumeInbound returned ok=false") } if got.Content != "hello" { t.Fatalf("expected content 'hello', got %q", got.Content) } if got.Channel != "test" { t.Fatalf("expected channel 'test', got %q", got.Channel) } } func TestPublishOutboundSubscribe(t *testing.T) { mb := NewMessageBus() defer mb.Close() ctx := context.Background() msg := OutboundMessage{ Channel: "telegram", ChatID: "123", Content: "world", } if err := mb.PublishOutbound(ctx, msg); err != nil { t.Fatalf("PublishOutbound failed: %v", err) } got, ok := <-mb.OutboundChan() if !ok { t.Fatal("SubscribeOutbound returned ok=false") } if got.Content != "world" { t.Fatalf("expected content 'world', got %q", got.Content) } } func TestPublishInbound_ContextCancel(t *testing.T) { mb := NewMessageBus() defer mb.Close() // Fill the buffer ctx := context.Background() for i := range defaultBusBufferSize { if err := mb.PublishInbound(ctx, 
InboundMessage{Content: "fill"}); err != nil {
			t.Fatalf("fill failed at %d: %v", i, err)
		}
	}

	// Now buffer is full; publish with a canceled context
	cancelCtx, cancel := context.WithCancel(context.Background())
	cancel()

	err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
	if err == nil {
		t.Fatal("expected error from canceled context, got nil")
	}
	if err != context.Canceled {
		t.Fatalf("expected context.Canceled, got %v", err)
	}
}

func TestPublishInbound_BusClosed(t *testing.T) {
	mb := NewMessageBus()
	mb.Close()

	err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
	if err != ErrBusClosed {
		t.Fatalf("expected ErrBusClosed, got %v", err)
	}
}

func TestPublishOutbound_BusClosed(t *testing.T) {
	mb := NewMessageBus()
	mb.Close()

	err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
	if err != ErrBusClosed {
		t.Fatalf("expected ErrBusClosed, got %v", err)
	}
}

func TestConsumeInbound_ContextCancel(t *testing.T) {
	mb := NewMessageBus()
	defer mb.Close()

	for i := range defaultBusBufferSize {
		if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
			t.Fatalf("fill failed at %d: %v", i, err)
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// The buffer is full, so this call blocks until ctx expires; the error is
	// intentionally ignored here.
	_ = mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})

	select {
	case <-ctx.Done():
		t.Log("context canceled, as expected")
	case msg, ok := <-mb.InboundChan():
		if !ok {
			t.Fatal("inbound channel closed unexpectedly")
		}
		if msg.Content == "ContextCancel" {
			t.Fatalf("message published with an expired context must not be delivered, got %q", msg.Content)
		}
	}
}

func TestConsumeInbound_BusClosed(t *testing.T) {
	mb := NewMessageBus()

	// A timer created by time.AfterFunc has a nil C field, so the test waits
	// on the inbound channel and uses a separate timeout as a safety net.
	timer := time.AfterFunc(100*time.Millisecond, mb.Close)
	defer timer.Stop()

	select {
	case _, ok := <-mb.InboundChan():
		if ok {
			t.Fatal("expected ok=false when bus is closed")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for the bus to close")
	}
}

func TestSubscribeOutbound_BusClosed(t *testing.T)
{ mb := NewMessageBus() mb.Close() _, ok := <-mb.OutboundChan() if ok { t.Fatal("expected ok=false when bus is closed") } } func TestConcurrentPublishClose(t *testing.T) { mb := NewMessageBus() ctx := context.Background() const numGoroutines = 100 var wg sync.WaitGroup wg.Add(numGoroutines + 1) // Spawn many goroutines trying to publish for range numGoroutines { go func() { defer wg.Done() // Use a short timeout context so we don't block forever after close publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) defer cancel() // Errors are expected; we just must not panic or deadlock _ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"}) }() } // Close from another goroutine go func() { defer wg.Done() time.Sleep(5 * time.Millisecond) mb.Close() }() // Must complete without deadlock done := make(chan struct{}) go func() { wg.Wait() close(done) }() select { case <-done: // success case <-time.After(5 * time.Second): t.Fatal("test timed out - possible deadlock") } } func TestPublishInbound_FullBuffer(t *testing.T) { mb := NewMessageBus() defer mb.Close() ctx := context.Background() // Fill the buffer for i := range defaultBusBufferSize { if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { t.Fatalf("fill failed at %d: %v", i, err) } } // Buffer is full; publish with short timeout timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) if err == nil { t.Fatal("expected error when buffer is full and context times out") } if err != context.DeadlineExceeded { t.Fatalf("expected context.DeadlineExceeded, got %v", err) } } func TestCloseIdempotent(t *testing.T) { mb := NewMessageBus() // Multiple Close calls must not panic mb.Close() mb.Close() mb.Close() // After close, publish should return ErrBusClosed err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) if err != 
ErrBusClosed { t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) } } ================================================ FILE: pkg/bus/types.go ================================================ package bus // Peer identifies the routing peer for a message (direct, group, channel, etc.) type Peer struct { Kind string `json:"kind"` // "direct" | "group" | "channel" | "" ID string `json:"id"` } // SenderInfo provides structured sender identity information. type SenderInfo struct { Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ... PlatformID string `json:"platform_id,omitempty"` // raw platform ID, e.g. "123456" CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" format Username string `json:"username,omitempty"` // username (e.g. @alice) DisplayName string `json:"display_name,omitempty"` // display name } type InboundMessage struct { Channel string `json:"channel"` SenderID string `json:"sender_id"` Sender SenderInfo `json:"sender"` ChatID string `json:"chat_id"` Content string `json:"content"` Media []string `json:"media,omitempty"` Peer Peer `json:"peer"` // routing peer MessageID string `json:"message_id,omitempty"` // platform message ID MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope SessionKey string `json:"session_key"` Metadata map[string]string `json:"metadata,omitempty"` } type OutboundMessage struct { Channel string `json:"channel"` ChatID string `json:"chat_id"` Content string `json:"content"` ReplyToMessageID string `json:"reply_to_message_id,omitempty"` } // MediaPart describes a single media attachment to send. type MediaPart struct { Type string `json:"type"` // "image" | "audio" | "video" | "file" Ref string `json:"ref"` // media store ref, e.g. 
"media://abc123" Caption string `json:"caption,omitempty"` // optional caption text Filename string `json:"filename,omitempty"` // original filename hint ContentType string `json:"content_type,omitempty"` // MIME type hint } // OutboundMediaMessage carries media attachments from Agent to channels via the bus. type OutboundMediaMessage struct { Channel string `json:"channel"` ChatID string `json:"chat_id"` Parts []MediaPart `json:"parts"` } ================================================ FILE: pkg/channels/README.md ================================================ # PicoClaw Channel System: Complete Development Guide > **Scope**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/` --- ## Table of Contents - [Part 1: Architecture Overview](#part-1-architecture-overview) - [Part 2: Migration Guide — From main Branch to Refactored Branch](#part-2-migration-guide--from-main-branch-to-refactored-branch) - [Part 3: New Channel Development Guide — Implementing a Channel from Scratch](#part-3-new-channel-development-guide--implementing-a-channel-from-scratch) - [Part 4: Core Subsystem Details](#part-4-core-subsystem-details) - [Part 5: Key Design Decisions and Conventions](#part-5-key-design-decisions-and-conventions) - [Appendix: Complete File Listing and Interface Quick Reference](#appendix-complete-file-listing-and-interface-quick-reference) --- ## Part 1: Architecture Overview ### 1.1 Before and After Comparison **Before Refactor (main branch)**: ``` pkg/channels/ ├── telegram.go # Each channel directly in the channels package ├── discord.go ├── slack.go ├── manager.go # Manager directly references each channel type ├── ... 
```

- All channel implementations lived at the top level of `pkg/channels/`
- Manager constructed each channel via `switch` or `if-else` chains
- Routing info like Peer and MessageID was buried in `Metadata map[string]string`
- No rate limiting or retry on message sending
- No unified media file lifecycle management
- Each channel ran its own HTTP server
- Group chat trigger filtering logic was scattered across channels

**After Refactor (refactor/channel-system branch)**:

```
pkg/channels/
├── base.go              # BaseChannel shared abstraction layer
├── interfaces.go        # Optional capability interfaces (TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
├── README.md            # English documentation
├── README.zh.md         # Chinese documentation
├── media.go             # MediaSender optional interface
├── webhook.go           # WebhookHandler, HealthChecker optional interfaces
├── errors.go            # Sentinel errors (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
├── errutil.go           # Error classification helpers
├── registry.go          # Factory registry (RegisterFactory / getFactory)
├── manager.go           # Unified orchestration: Worker queues, rate limiting, retries, Typing/Placeholder, shared HTTP
├── split.go             # Smart long-message splitting (preserves code block integrity)
├── telegram/            # Each channel in its own sub-package
│   ├── init.go          # Factory registration
│   ├── telegram.go      # Implementation
│   └── telegram_commands.go
├── discord/
│   ├── init.go
│   └── discord.go
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
│   └── ...
pkg/bus/ ├── bus.go # MessageBus (buffer 64, safe close + drain) ├── types.go # Structured message types (Peer, SenderInfo, MediaPart, InboundMessage, OutboundMessage, OutboundMediaMessage) pkg/media/ ├── store.go # MediaStore interface + FileMediaStore implementation (two-phase release, TTL cleanup) pkg/identity/ ├── identity.go # Unified user identity: canonical "platform:id" format + backward-compatible matching ``` ### 1.2 Message Flow Overview ``` ┌────────────┐ InboundMessage ┌───────────┐ LLM + Tools ┌────────────┐ │ Telegram │──┐ │ │ │ │ │ Discord │──┤ PublishInbound() │ │ PublishOutbound() │ │ │ Slack │──┼──────────────────────▶ │ MessageBus │ ◀─────────────────── │ AgentLoop │ │ LINE │──┤ (buffered chan, 64) │ │ (buffered chan, 64) │ │ │ ... │──┘ │ │ │ │ └────────────┘ └─────┬─────┘ └────────────┘ │ SubscribeOutbound() │ SubscribeOutboundMedia() ▼ ┌───────────────────┐ │ Manager │ │ ├── dispatchOutbound() Route to Worker queues │ ├── dispatchOutboundMedia() │ ├── runWorker() Message split + sendWithRetry() │ ├── runMediaWorker() sendMediaWithRetry() │ ├── preSend() Stop Typing + Undo Reaction + Edit Placeholder │ └── runTTLJanitor() Clean up expired Typing/Placeholder └────────┬──────────┘ │ channel.Send() / SendMedia() │ ▼ ┌────────────────┐ │ Platform APIs │ └────────────────┘ ``` ### 1.3 Key Design Principles | Principle | Description | |-----------|-------------| | **Sub-package Isolation** | Each channel is a standalone Go sub-package, depending on `BaseChannel` and interfaces from the `channels` parent package | | **Factory Registration** | Sub-packages self-register via `init()`, Manager looks up factories by name, eliminating import coupling | | **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`), discovered by Manager via runtime type assertions | | **Structured Messages** | Peer, MessageID, and 
SenderInfo promoted from Metadata to first-class fields on InboundMessage | | **Error Classification** | Channels return sentinel errors (`ErrRateLimit`, `ErrTemporary`, etc.), Manager uses these to determine retry strategy | | **Centralized Orchestration** | Rate limiting, message splitting, retries, and Typing/Reaction/Placeholder management are all handled by Manager and BaseChannel; channels only need to implement Send | --- ## Part 2: Migration Guide — From main Branch to Refactored Branch ### 2.1 If You Have Unmerged Channel Changes #### Step 1: Identify which files you modified On the main branch, channel files were directly in `pkg/channels/` top level, e.g.: - `pkg/channels/telegram.go` - `pkg/channels/discord.go` After refactoring, these files have been removed and code moved to corresponding sub-packages: - `pkg/channels/telegram/telegram.go` - `pkg/channels/discord/discord.go` #### Step 2: Understand the structural change mapping | main branch file | Refactored branch location | Changes | |---|---|---| | `pkg/channels/telegram.go` | `pkg/channels/telegram/telegram.go` + `init.go` | Package name changed from `channels` to `telegram` | | `pkg/channels/discord.go` | `pkg/channels/discord/discord.go` + `init.go` | Same as above | | `pkg/channels/manager.go` | `pkg/channels/manager.go` | Extensively rewritten | | _(did not exist)_ | `pkg/channels/base.go` | New shared abstraction layer | | _(did not exist)_ | `pkg/channels/registry.go` | New factory registry | | _(did not exist)_ | `pkg/channels/errors.go` + `errutil.go` | New error classification system | | _(did not exist)_ | `pkg/channels/interfaces.go` | New optional capability interfaces | | _(did not exist)_ | `pkg/channels/media.go` | New MediaSender interface | | _(did not exist)_ | `pkg/channels/webhook.go` | New WebhookHandler/HealthChecker | | _(did not exist)_ | `pkg/channels/whatsapp_native/` | New WhatsApp native mode (whatsmeow) | | _(did not exist)_ | `pkg/channels/split.go` | New message 
splitting (migrated from utils) | | _(did not exist)_ | `pkg/bus/types.go` | New structured message types | | _(did not exist)_ | `pkg/media/store.go` | New media file lifecycle management | | _(did not exist)_ | `pkg/identity/identity.go` | New unified user identity | #### Step 3: Migrate your channel code Using Telegram as an example, the main changes are: **3a. Package declaration and imports** ```go // Old code (main branch) package channels import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" ) // New code (refactored branch) package telegram import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" // Reference parent package "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" // New "github.com/sipeed/picoclaw/pkg/media" // New (if media support needed) ) ``` **3b. Struct embeds BaseChannel** ```go // Old code: directly held bus, config, etc. fields type TelegramChannel struct { bus *bus.MessageBus config *config.Config running bool allowList []string // ... } // New code: embed BaseChannel, which provides bus, running, allowList, etc. type TelegramChannel struct { *channels.BaseChannel // Embed shared abstraction bot *telego.Bot config *config.Config // ... only channel-specific fields } ``` **3c. Constructor** ```go // Old code: direct assignment func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { return &TelegramChannel{ bus: bus, config: cfg, allowList: cfg.Channels.Telegram.AllowFrom, // ... 
}, nil } // New code: use NewBaseChannel + functional options func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { base := channels.NewBaseChannel( "telegram", // Name cfg.Channels.Telegram, // Raw config (any type) bus, // Message bus cfg.Channels.Telegram.AllowFrom, // Allow list channels.WithMaxMessageLength(4096), // Platform message length limit channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // Group trigger config channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // Reasoning chain routing ) return &TelegramChannel{ BaseChannel: base, bot: bot, config: cfg, }, nil } ``` **3d. Start/Stop lifecycle** ```go // New code: use SetRunning atomic operation func (c *TelegramChannel) Start(ctx context.Context) error { // ... initialize bot, webhook, etc. c.SetRunning(true) // Must be called after ready go bh.Start() return nil } func (c *TelegramChannel) Stop(ctx context.Context) error { c.SetRunning(false) // Must be called before cleanup // ... stop bot handler, cancel context return nil } ``` **3e. Send method error returns** ```go // Old code: returns plain error func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.running { return fmt.Errorf("not running") } // ... if err != nil { return err } } // New code: must return sentinel errors for Manager to determine retry strategy func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning // ← Manager will not retry } // ... if err != nil { // Use ClassifySendError to wrap error based on HTTP status code return channels.ClassifySendError(statusCode, err) // Or manually wrap: // return fmt.Errorf("%w: %v", channels.ErrTemporary, err) // return fmt.Errorf("%w: %v", channels.ErrRateLimit, err) // return fmt.Errorf("%w: %v", channels.ErrSendFailed, err) } return nil } ``` **3f. 
Message reception (Inbound)** ```go // Old code: directly construct InboundMessage and publish msg := bus.InboundMessage{ Channel: "telegram", SenderID: senderID, ChatID: chatID, Content: content, Metadata: map[string]string{ "peer_kind": "group", // Routing info buried in metadata "peer_id": chatID, "message_id": msgID, }, } c.bus.PublishInbound(ctx, msg) // New code: use BaseChannel.HandleMessage with structured fields sender := bus.SenderInfo{ Platform: "telegram", PlatformID: strconv.FormatInt(from.ID, 10), CanonicalID: identity.BuildCanonicalID("telegram", strconv.FormatInt(from.ID, 10)), Username: from.Username, DisplayName: from.FirstName, } peer := bus.Peer{ Kind: "group", // or "direct" ID: chatID, } // HandleMessage internally calls IsAllowedSender for permission checks, builds MediaScope, and publishes to bus c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, sender) ``` **3g. Add factory registration (required)** Create `init.go` for your channel: ```go // pkg/channels/telegram/init.go package telegram import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) func init() { channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { return NewTelegramChannel(cfg, b) }) } ``` **3h. 
Import sub-package in Gateway** ```go // cmd/picoclaw/internal/gateway/helpers.go import ( _ "github.com/sipeed/picoclaw/pkg/channels/telegram" // Triggers init() registration _ "github.com/sipeed/picoclaw/pkg/channels/discord" _ "github.com/sipeed/picoclaw/pkg/channels/your_new_channel" // New addition ) ``` #### Step 4: Migrate bus message usage If your code directly reads routing fields from `InboundMessage.Metadata`: ```go // Old code peerKind := msg.Metadata["peer_kind"] peerID := msg.Metadata["peer_id"] msgID := msg.Metadata["message_id"] // New code peerKind := msg.Peer.Kind // First-class field peerID := msg.Peer.ID // First-class field msgID := msg.MessageID // First-class field sender := msg.Sender // bus.SenderInfo struct scope := msg.MediaScope // Media lifecycle scope ``` #### Step 5: Migrate allow-list checks ```go // Old code if !c.isAllowed(senderID) { return } // New code: prefer structured check if !c.IsAllowedSender(sender) { return } // Or fall back to string check: if !c.IsAllowed(senderID) { return } ``` `BaseChannel.HandleMessage` already handles this logic internally — no need to duplicate the check in your channel. ### 2.2 If You Have Manager Modifications The Manager has been completely rewritten. 
Your modifications will need to account for the new architecture:

| Old Manager Responsibility | New Manager Responsibility |
|---|---|
| Directly construct channels (switch/if-else) | Look up and construct via factory registry |
| Directly call channel.Send | Per-channel Worker queues + rate limiting + retries |
| No message splitting | Automatic splitting based on MaxMessageLength |
| Each channel runs its own HTTP server | Unified shared HTTP server |
| No Typing/Placeholder management | Unified preSend handles Typing stop + Reaction undo + Placeholder edit; inbound-side BaseChannel.HandleMessage auto-orchestrates Typing/Reaction/Placeholder |
| No TTL cleanup | runTTLJanitor periodically cleans up expired Typing/Reaction/Placeholder entries |

### 2.3 If You Have Agent Loop Modifications

Main changes to the Agent Loop:

1. **MediaStore injection**: `agentLoop.SetMediaStore(mediaStore)`; the Agent resolves media references produced by tools via MediaStore
2. **ChannelManager injection**: `agentLoop.SetChannelManager(channelManager)`; the Agent can query channel state
3. **OutboundMediaMessage**: Agent now sends media messages via `bus.PublishOutboundMedia()` instead of embedding them in text replies
4. **extractPeer**: Routing uses `msg.Peer` structured fields instead of Metadata lookups

---

## Part 3: New Channel Development Guide — Implementing a Channel from Scratch

### 3.1 Minimum Implementation Checklist

To add a new chat platform (e.g., `matrix`), you need to:

1. ✅ Create sub-package directory `pkg/channels/matrix/`
2. ✅ Create `init.go` (factory registration)
3. ✅ Create `matrix.go` (channel implementation)
4. ✅ Add blank import in Gateway helpers
5. ✅ Add config check in Manager.initChannels()
6.
✅ Add config struct in `pkg/config/` ### 3.2 Complete Template #### `pkg/channels/matrix/init.go` ```go package matrix import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) func init() { channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { return NewMatrixChannel(cfg, b) }) } ``` #### `pkg/channels/matrix/matrix.go` ```go package matrix import ( "context" "fmt" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) // MatrixChannel implements channels.Channel for the Matrix protocol. type MatrixChannel struct { *channels.BaseChannel // Must embed config *config.Config ctx context.Context cancel context.CancelFunc // ... Matrix SDK client, etc. } func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChannel, error) { matrixCfg := cfg.Channels.Matrix // Assumes this field exists in config base := channels.NewBaseChannel( "matrix", // Channel name (globally unique) matrixCfg, // Raw config msgBus, // Message bus matrixCfg.AllowFrom, // Allow list channels.WithMaxMessageLength(65536), // Matrix message length limit channels.WithGroupTrigger(matrixCfg.GroupTrigger), channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // Reasoning chain routing (optional) ) return &MatrixChannel{ BaseChannel: base, config: cfg, }, nil } // ========== Required Channel Interface Methods ========== func (c *MatrixChannel) Start(ctx context.Context) error { c.ctx, c.cancel = context.WithCancel(ctx) // 1. Initialize Matrix client // 2. Start listening for messages // 3. 
Mark as running c.SetRunning(true) logger.InfoC("matrix", "Matrix channel started") return nil } func (c *MatrixChannel) Stop(ctx context.Context) error { c.SetRunning(false) if c.cancel != nil { c.cancel() } logger.InfoC("matrix", "Matrix channel stopped") return nil } func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { // 1. Check running state if !c.IsRunning() { return channels.ErrNotRunning } // 2. Send message to Matrix err := c.sendToMatrix(ctx, msg.ChatID, msg.Content) if err != nil { // 3. Must use error classification wrapping // If you have an HTTP status code: // return channels.ClassifySendError(statusCode, err) // If it's a network error: // return channels.ClassifyNetError(err) // If manual classification is needed: return fmt.Errorf("%w: %v", channels.ErrTemporary, err) } return nil } // ========== Incoming Message Handling ========== func (c *MatrixChannel) handleIncoming(roomID, senderID, displayName, content string, msgID string) { // 1. Construct structured sender identity sender := bus.SenderInfo{ Platform: "matrix", PlatformID: senderID, CanonicalID: identity.BuildCanonicalID("matrix", senderID), Username: senderID, DisplayName: displayName, } // 2. Determine Peer type (direct vs group) peer := bus.Peer{ Kind: "group", // or "direct" ID: roomID, } // 3. Group chat filtering (if applicable) isGroup := peer.Kind == "group" if isGroup { isMentioned := false // Detect @mentions based on platform specifics shouldRespond, cleanContent := c.ShouldRespondInGroup(isMentioned, content) if !shouldRespond { return } content = cleanContent } // 4. Handle media attachments (if any) var mediaRefs []string store := c.GetMediaStore() if store != nil { // Download attachment locally → store.Store() → get ref // mediaRefs = append(mediaRefs, ref) } // 5. 
Call HandleMessage to publish to bus
	// HandleMessage internally will:
	//   - Check IsAllowedSender/IsAllowed
	//   - Build MediaScope
	//   - Publish InboundMessage
	c.HandleMessage(
		c.ctx,
		peer,
		msgID,     // Platform message ID
		senderID,  // Raw sender ID
		roomID,    // Chat/room ID
		content,   // Message content
		mediaRefs, // Media reference list
		nil,       // Extra metadata (usually nil)
		sender,    // SenderInfo (variadic parameter)
	)
}

// ========== Internal Methods ==========

func (c *MatrixChannel) sendToMatrix(ctx context.Context, roomID, content string) error {
	// Actual Matrix SDK call
	return nil
}
```

### 3.3 Optional Capability Interfaces

Depending on platform capabilities, your channel can optionally implement the following interfaces:

#### MediaSender — Send Media Attachments

```go
// If the platform supports sending images/files/audio/video
func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
	if !c.IsRunning() {
		return channels.ErrNotRunning
	}

	store := c.GetMediaStore()
	if store == nil {
		return fmt.Errorf("no media store: %w", channels.ErrSendFailed)
	}

	for _, part := range msg.Parts {
		localPath, err := store.Resolve(part.Ref)
		if err != nil {
			logger.ErrorCF("matrix", "Failed to resolve media", map[string]any{
				"ref":   part.Ref,
				"error": err.Error(),
			})
			continue
		}

		// Call the appropriate API based on part.Type ("image"|"audio"|"video"|"file")
		switch part.Type {
		case "image":
			// Upload the image at localPath to Matrix
		default:
			// Upload the file at localPath to Matrix
		}
		_ = localPath // Template only: remove once localPath is passed to a real upload call
	}
	return nil
}
```

#### TypingCapable — Typing Indicator

```go
// If the platform supports "typing..."
indicators
func (c *MatrixChannel) StartTyping(ctx context.Context, chatID string) (stop func(), err error) {
	// Call Matrix API to send typing indicator

	// The returned stop function must be idempotent. sync.Once (requires
	// importing "sync") also keeps it safe when called from several goroutines.
	var once sync.Once
	return func() {
		once.Do(func() {
			// Call Matrix API to stop typing
		})
	}, nil
}
```

#### ReactionCapable — Message Reaction Indicator

```go
// If the platform supports adding emoji reactions to inbound messages (e.g., Slack's 👀, OneBot's emoji 289)
func (c *MatrixChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) {
	// Call Matrix API to add reaction to message
	err = c.addReaction(chatID, messageID, "eyes")
	if err != nil {
		return func() {}, err
	}

	// The returned undo function removes the reaction and must be idempotent,
	// so the removal is guarded by sync.Once (requires importing "sync").
	var once sync.Once
	return func() {
		once.Do(func() {
			c.removeReaction(chatID, messageID, "eyes")
		})
	}, nil
}
```

#### MessageEditor — Message Editing

```go
// If the platform supports editing sent messages (used for Placeholder replacement)
func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
	// Call Matrix API to edit message
	return nil
}
```

#### PlaceholderCapable — Placeholder Messages

```go
// If the platform supports sending placeholder messages (e.g. "Thinking... 💭"),
// and the channel also implements MessageEditor, then Manager's preSend will
// automatically edit the placeholder into the final response on outbound.
// SendPlaceholder checks PlaceholderConfig.Enabled internally;
// returning ("", nil) means skip.
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
	cfg := c.config.Channels.Matrix.Placeholder
	if !cfg.Enabled {
		return "", nil
	}

	text := cfg.Text
	if text == "" {
		text = "Thinking...
💭" } // Call Matrix API to send placeholder message msg, err := c.sendText(ctx, chatID, text) if err != nil { return "", err } return msg.ID, nil } ``` #### WebhookHandler — HTTP Webhook Reception ```go // If the channel receives messages via webhook (rather than long-polling/WebSocket) func (c *MatrixChannel) WebhookPath() string { return "/webhook/matrix" // Path will be registered on the shared HTTP server } func (c *MatrixChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Handle webhook request } ``` #### HealthChecker — Health Check Endpoint ```go func (c *MatrixChannel) HealthPath() string { return "/health/matrix" } func (c *MatrixChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { if c.IsRunning() { w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) } else { w.WriteHeader(http.StatusServiceUnavailable) } } ``` ### 3.4 Inbound-side Typing/Reaction/Placeholder Auto-orchestration `BaseChannel.HandleMessage` automatically detects whether the channel implements `TypingCapable`, `ReactionCapable`, and/or `PlaceholderCapable` **before** publishing the inbound message, and triggers the corresponding indicators. 
The three pipelines are completely independent and do not interfere with each other: ```go // Automatically executed inside BaseChannel.HandleMessage (no manual calls needed): if c.owner != nil && c.placeholderRecorder != nil { // Typing — independent pipeline if tc, ok := c.owner.(TypingCapable); ok { if stop, err := tc.StartTyping(ctx, chatID); err == nil { c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop) } } // Reaction — independent pipeline if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" { if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil { c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo) } } // Placeholder — independent pipeline if pc, ok := c.owner.(PlaceholderCapable); ok { if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" { c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID) } } } ``` **This means**: - Channels implementing `TypingCapable` (Telegram, Discord, LINE, Pico) do not need to manually call `StartTyping` + `RecordTypingStop` in `handleMessage` - Channels implementing `ReactionCapable` (Slack, OneBot) do not need to manually call `AddReaction` + `RecordTypingStop` in `handleMessage` - Channels implementing `PlaceholderCapable` (Telegram, Discord, Pico) do not need to manually send placeholder messages and call `RecordPlaceholder` in `handleMessage` - Channels only need to implement the corresponding interface; `HandleMessage` handles orchestration automatically - Channels that don't implement these interfaces are unaffected (type assertions will fail and be skipped) - `PlaceholderCapable`'s `SendPlaceholder` method internally decides whether to send based on the configured `PlaceholderConfig.Enabled`; returning `("", nil)` skips registration **Owner Injection**: Manager automatically calls `SetOwner(ch)` in `initChannel` to inject the concrete channel into BaseChannel — no manual setup required from developers. 
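
The owner injection above exists because methods defined on an embedded struct cannot see the methods of the type that embeds it. The following is a minimal sketch of that pattern, not the project's actual code: `baseChannel`, `typingChannel`, `plainChannel`, and the two constructors are hypothetical names, and `TypingCapable` is simplified to drop the `context.Context` parameter.

```go
package main

import "fmt"

// TypingCapable is a simplified stand-in for the optional interface
// described above (the real one also takes a context.Context).
type TypingCapable interface {
	StartTyping(chatID string) (stop func(), err error)
}

// baseChannel is a hypothetical, stripped-down BaseChannel.
type baseChannel struct {
	owner any // concrete channel, injected via SetOwner
}

func (b *baseChannel) SetOwner(ch any) { b.owner = ch }

// HandleMessage reports whether a typing indicator was started and returns
// the stop function to record (a no-op when the capability is missing).
func (b *baseChannel) HandleMessage(chatID string) (bool, func()) {
	if tc, ok := b.owner.(TypingCapable); ok {
		if stop, err := tc.StartTyping(chatID); err == nil {
			return true, stop
		}
	}
	return false, func() {}
}

// typingChannel implements TypingCapable.
type typingChannel struct {
	*baseChannel
	active map[string]bool
}

func (c *typingChannel) StartTyping(chatID string) (func(), error) {
	c.active[chatID] = true
	return func() { delete(c.active, chatID) }, nil
}

// plainChannel implements no optional capability.
type plainChannel struct{ *baseChannel }

func newTypingChannel() *typingChannel {
	ch := &typingChannel{baseChannel: &baseChannel{}, active: map[string]bool{}}
	ch.SetOwner(ch) // the step Manager performs in initChannel, per the text above
	return ch
}

func newPlainChannel() *plainChannel {
	ch := &plainChannel{baseChannel: &baseChannel{}}
	ch.SetOwner(ch)
	return ch
}

func main() {
	tc := newTypingChannel()
	started, stop := tc.HandleMessage("room1")
	fmt.Println(started, tc.active["room1"]) // true true
	stop()
	fmt.Println(tc.active["room1"]) // false

	pc := newPlainChannel()
	started, _ = pc.HandleMessage("room1")
	fmt.Println(started) // false
}
```

Without the `SetOwner` call, `b.owner` would be nil and every type assertion would fail, which is why the capability checks in `HandleMessage` are guarded by `c.owner != nil`.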
When the Agent finishes processing a message, Manager's `preSend` automatically:

1. Calls the recorded `stop()` to stop Typing
2. Calls the recorded `undo()` to undo Reaction
3. If there is a Placeholder and the channel implements `MessageEditor`, attempts to edit the Placeholder with the final reply (skipping Send)

### 3.5 Register Configuration and Gateway Integration

#### Add configuration in `pkg/config/config.go`

```go
type ChannelsConfig struct {
    // ... existing channels
    Matrix MatrixChannelConfig `json:"matrix"`
}

type MatrixChannelConfig struct {
    Enabled            bool               `json:"enabled"`
    HomeServer         string             `json:"home_server"`
    Token              string             `json:"token"`
    AllowFrom          []string           `json:"allow_from"`
    GroupTrigger       GroupTriggerConfig `json:"group_trigger"`
    Placeholder        PlaceholderConfig  `json:"placeholder"`
    ReasoningChannelID string             `json:"reasoning_channel_id"`
}
```

#### Add entry in Manager.initChannels()

```go
// In the initChannels() method of pkg/channels/manager.go
if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
    m.initChannel("matrix", "Matrix")
}
```

> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
> ```go
> if cfg.UseNative {
>     m.initChannel("whatsapp_native", "WhatsApp Native")
> } else {
>     m.initChannel("whatsapp", "WhatsApp")
> }
> ```

#### Add blank import in Gateway

```go
// cmd/picoclaw/internal/gateway/helpers.go
import (
    _ "github.com/sipeed/picoclaw/pkg/channels/matrix"
)
```

---

## Part 4: Core Subsystem Details

### 4.1 MessageBus

**Files**: `pkg/bus/bus.go`, `pkg/bus/types.go`

```go
type MessageBus struct {
    inbound       chan InboundMessage       // buffer = 64
    outbound      chan OutboundMessage      // buffer = 64
    outboundMedia chan OutboundMediaMessage // buffer = 64
    done          chan struct{}             // Close signal
    closed        atomic.Bool               // Prevents double-close
}
```

**Key Behaviors**:

| Method | Behavior |
|--------|----------|
| `PublishInbound(ctx, msg)` | Check closed → send to inbound channel → block/timeout/close |
| `ConsumeInbound(ctx)` | Read from inbound → block/close/cancel |
| `PublishOutbound(ctx, msg)` | Send to outbound channel |
| `SubscribeOutbound(ctx)` | Read from outbound (called by Manager dispatcher) |
| `PublishOutboundMedia(ctx, msg)` | Send to outboundMedia channel |
| `SubscribeOutboundMedia(ctx)` | Read from outboundMedia (called by Manager media dispatcher) |
| `Close()` | CAS close → close(done) → drain all channels (**does not close the channels themselves** to avoid concurrent send-on-closed panic) |

**Design Notes**:

- Buffer size increased from 16 to 64 to reduce blocking under burst load
- `Close()` does not close the underlying channels (only closes the `done` signal channel), because there may be concurrent `Publish` goroutines
- Drain loop ensures buffered messages are not silently dropped

### 4.2 Structured Message Types

**File**: `pkg/bus/types.go`

```go
// Routing peer
type Peer struct {
    Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
    ID   string `json:"id"`
}

// Sender identity information
type SenderInfo struct {
    Platform    string `json:"platform,omitempty"`     // "telegram", "discord", ...
    PlatformID  string `json:"platform_id,omitempty"`  // Platform-native ID
    CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" canonical format
    Username    string `json:"username,omitempty"`
    DisplayName string `json:"display_name,omitempty"`
}

// Inbound message
type InboundMessage struct {
    Channel    string            // Source channel name
    SenderID   string            // Sender ID (prefer CanonicalID)
    Sender     SenderInfo        // Structured sender info
    ChatID     string            // Chat/room ID
    Content    string            // Message text
    Media      []string          // Media reference list (media://...)
    Peer       Peer              // Routing peer (first-class field)
    MessageID  string            // Platform message ID (first-class field)
    MediaScope string            // Media lifecycle scope
    SessionKey string            // Session key
    Metadata   map[string]string // Only for channel-specific extensions
}

// Outbound text message
type OutboundMessage struct {
    Channel string
    ChatID  string
    Content string
}

// Outbound media message
type OutboundMediaMessage struct {
    Channel string
    ChatID  string
    Parts   []MediaPart
}

// Media part
type MediaPart struct {
    Type        string // "image" | "audio" | "video" | "file"
    Ref         string // "media://uuid"
    Caption     string
    Filename    string
    ContentType string
}
```

### 4.3 BaseChannel

**File**: `pkg/channels/base.go`

BaseChannel is the shared abstraction layer for all channels, providing the following capabilities:

| Method/Feature | Description |
|---|---|
| `Name() string` | Channel name |
| `IsRunning() bool` | Atomically read running state |
| `SetRunning(bool)` | Atomically set running state |
| `MaxMessageLength() int` | Message length limit (rune count), 0 = unlimited |
| `ReasoningChannelID() string` | Reasoning chain routing target channel ID (empty = no routing) |
| `IsAllowed(senderID string) bool` | Legacy allow-list check (supports `"id\|username"` and `"@username"` formats) |
| `IsAllowedSender(sender SenderInfo) bool` | New allow-list check (delegates to `identity.MatchAllowed`) |
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | Unified group chat trigger filtering logic |
| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction/Placeholder → publish to Bus |
| `SetMediaStore(s) / GetMediaStore()` | MediaStore injected by Manager |
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | PlaceholderRecorder injected by Manager |
| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction/Placeholder type assertions in HandleMessage) |

**Functional Options**:

```go
channels.WithMaxMessageLength(4096)        // Set platform message length limit
channels.WithGroupTrigger(groupTriggerCfg) // Set group trigger configuration
channels.WithReasoningChannelID(id)        // Set reasoning chain routing target channel
```

### 4.4 Factory Registry

**File**: `pkg/channels/registry.go`

```go
type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)

func RegisterFactory(name string, f ChannelFactory)  // Called in sub-package init()
func getFactory(name string) (ChannelFactory, bool)  // Called internally by Manager
```

The factory registry is protected by `sync.RWMutex` and registrations occur during `init()` phase (completed at process startup). Manager looks up factories by name in `initChannel()` and calls them.

### 4.5 Error Classification and Retries

**Files**: `pkg/channels/errors.go`, `pkg/channels/errutil.go`

#### Sentinel Errors

```go
var (
    ErrNotRunning = errors.New("channel not running") // Permanent: do not retry
    ErrRateLimit  = errors.New("rate limited")        // Fixed delay: retry after 1s
    ErrTemporary  = errors.New("temporary failure")   // Exponential backoff: 500ms * 2^attempt, max 8s
    ErrSendFailed = errors.New("send failed")         // Permanent: do not retry
)
```

#### Error Classification Helpers

```go
// Automatically classify based on HTTP status code
func ClassifySendError(statusCode int, rawErr error) error {
    // 429 → ErrRateLimit
    // 5xx → ErrTemporary
    // 4xx → ErrSendFailed
}

// Wrap network errors as temporary
func ClassifyNetError(err error) error {
    // → ErrTemporary
}
```

#### Manager Retry Strategy (`sendWithRetry`)

```
Max retries:       3
Rate limit delay:  1 second
Base backoff:      500 milliseconds
Max backoff:       8 seconds

Retry logic:
  ErrNotRunning → Fail immediately, no retry
  ErrSendFailed → Fail immediately, no retry
  ErrRateLimit  → Wait 1s → retry
  ErrTemporary  → Wait 500ms * 2^attempt (max 8s) → retry
  Other unknown → Wait 500ms * 2^attempt (max 8s) → retry
```

### 4.6 Manager Orchestration

**File**: `pkg/channels/manager.go`
#### Per-channel Worker Architecture

```go
type channelWorker struct {
    ch         Channel                       // Channel instance
    queue      chan bus.OutboundMessage      // Outbound text queue (buffered 16)
    mediaQueue chan bus.OutboundMediaMessage // Outbound media queue (buffered 16)
    done       chan struct{}                 // Text worker completion signal
    mediaDone  chan struct{}                 // Media worker completion signal
    limiter    *rate.Limiter                 // Per-channel rate limiter
}
```

#### Per-channel Rate Limit Configuration

```go
var channelRateConfig = map[string]float64{
    "telegram": 20, // 20 msg/s
    "discord":  1,  // 1 msg/s
    "slack":    1,  // 1 msg/s
    "line":     10, // 10 msg/s
}
// Default: 10 msg/s
// burst = max(1, ceil(rate/2))
```

#### Lifecycle Management

```
StartAll:
  1. Iterate registered channels → channel.Start(ctx)
  2. Create channelWorker for each successfully started channel
  3. Start goroutines:
     - runWorker (per-channel outbound text)
     - runMediaWorker (per-channel outbound media)
     - dispatchOutbound (route from bus to worker queues)
     - dispatchOutboundMedia (route from bus to media worker queues)
     - runTTLJanitor (every 10s clean up expired typing/reaction/placeholder)
  4. Start shared HTTP server (if configured)

StopAll:
  1. Shut down shared HTTP server (5s timeout)
  2. Cancel dispatcher context
  3. Close text worker queues → wait for drain to complete
  4. Close media worker queues → wait for drain to complete
  5. Stop each channel (channel.Stop)
```

#### Typing/Reaction/Placeholder Management

```go
// Manager implements PlaceholderRecorder interface
func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string)
func (m *Manager) RecordTypingStop(channel, chatID string, stop func())
func (m *Manager) RecordReactionUndo(channel, chatID string, undo func())

// Inbound side: BaseChannel.HandleMessage auto-orchestrates
// BaseChannel.HandleMessage, before PublishInbound, auto-triggers via owner type assertions:
//   - TypingCapable.StartTyping → RecordTypingStop
//   - ReactionCapable.ReactToMessage → RecordReactionUndo
//   - PlaceholderCapable.SendPlaceholder → RecordPlaceholder
// All three are independent and do not interfere with each other. Channels don't need to call these manually.

// Outbound side: pre-send processing
func (m *Manager) preSend(ctx, name, msg, ch) bool {
    key := name + ":" + msg.ChatID
    // 1. Stop Typing (call stored stop function)
    // 2. Undo Reaction (call stored undo function)
    // 3. Attempt to edit Placeholder (if channel implements MessageEditor)
    //    Success → return true (skip Send)
    //    Failure → return false (proceed with Send)
}
```

Manager storage is fully separated; the three pipelines do not interfere:

```go
Manager {
    typingStops   sync.Map // "channel:chatID" → typingEntry      ← manages TypingCapable
    reactionUndos sync.Map // "channel:chatID" → reactionEntry    ← manages ReactionCapable
    placeholders  sync.Map // "channel:chatID" → placeholderEntry
}
```

TTL Cleanup:

- Typing stop functions: 5-minute TTL (auto-calls stop and deletes on expiry)
- Reaction undo functions: 5-minute TTL (auto-calls undo and deletes on expiry)
- Placeholder IDs: 10-minute TTL (deletes on expiry)
- Cleanup interval: 10 seconds

### 4.7 Message Splitting

**File**: `pkg/channels/split.go`

`SplitMessage(content string, maxLen int) []string`

Smart splitting strategy:

1. Calculate effective split point = maxLen - 10% buffer (to reserve space for code block closure)
2. Prefer splitting at newlines
3. Otherwise split at spaces/tabs
4. Detect unclosed code blocks (` ``` `)
5. If a code block is unclosed:
   - Attempt to extend to maxLen to include the closing fence
   - If the code block is too long, inject close/reopen fences (`\n```\n` + header)
   - Last resort: split before the code block starts

### 4.8 MediaStore

**File**: `pkg/media/store.go`

```go
type MediaStore interface {
    Store(localPath string, meta MediaMeta, scope string) (ref string, err error)
    Resolve(ref string) (localPath string, err error)
    ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error)
    ReleaseAll(scope string) error
}
```

**FileMediaStore Implementation**:

- Pure in-memory mapping, no file copy/move
- Reference format: `media://x"},
{"plain text", "just text", "just text"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := markdownToHTML(tt.input)
if !strings.Contains(got, tt.contains) {
t.Fatalf("markdownToHTML(%q) = %q, want it to contain %q", tt.input, got, tt.contains)
}
})
}
}
func TestMessageContent(t *testing.T) {
richtext := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "richtext"}}
plain := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "plain"}}
// unset leaves MessageFormat empty; it is expected to behave like "richtext".
unset := &MatrixChannel{config: config.MatrixConfig{}}
for _, c := range []*MatrixChannel{richtext, unset} {
mc := c.messageContent("**hi**")
if mc.Format != event.FormatHTML {
t.Errorf("format %q: expected FormatHTML, got %q", c.config.MessageFormat, mc.Format)
}
if !strings.Contains(mc.FormattedBody, "hi") {
t.Errorf("format %q: FormattedBody %q missing %q", c.config.MessageFormat, mc.FormattedBody, "hi")
}
if mc.Body != "**hi**" {
t.Errorf("format %q: Body should remain plain, got %q", c.config.MessageFormat, mc.Body)
}
}
mc := plain.messageContent("**hi**")
if mc.Format != "" || mc.FormattedBody != "" {
t.Errorf("plain: expected no formatting, got format=%q formattedBody=%q", mc.Format, mc.FormattedBody)
}
}
================================================
FILE: pkg/channels/media.go
================================================
package channels
import (
"context"
"github.com/sipeed/picoclaw/pkg/bus"
)
// MediaSender is an optional interface for channels that can send
// media attachments (images, files, audio, video).
// Manager discovers channels implementing this interface via type
// assertion and routes OutboundMediaMessage to them.
type MediaSender interface {
SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error
}
================================================
FILE: pkg/channels/onebot/init.go
================================================
package onebot
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func init() {
channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewOneBotChannel(cfg.Channels.OneBot, b)
})
}
================================================
FILE: pkg/channels/onebot/onebot.go
================================================
package onebot
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
)
type OneBotChannel struct {
*channels.BaseChannel
config config.OneBotConfig
conn *websocket.Conn
ctx context.Context
cancel context.CancelFunc
dedup map[string]struct{}
dedupRing []string
dedupIdx int
mu sync.Mutex
writeMu sync.Mutex
echoCounter int64
selfID int64
pending map[string]chan json.RawMessage
pendingMu sync.Mutex
lastMessageID sync.Map
}
type oneBotRawEvent struct {
PostType string `json:"post_type"`
MessageType string `json:"message_type"`
SubType string `json:"sub_type"`
MessageID json.RawMessage `json:"message_id"`
UserID json.RawMessage `json:"user_id"`
GroupID json.RawMessage `json:"group_id"`
RawMessage string `json:"raw_message"`
Message json.RawMessage `json:"message"`
Sender json.RawMessage `json:"sender"`
SelfID json.RawMessage `json:"self_id"`
Time json.RawMessage `json:"time"`
MetaEventType string `json:"meta_event_type"`
NoticeType string `json:"notice_type"`
Echo string `json:"echo"`
RetCode json.RawMessage `json:"retcode"`
Status json.RawMessage `json:"status"`
Data json.RawMessage `json:"data"`
}
type BotStatus struct {
Online bool `json:"online"`
Good bool `json:"good"`
}
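// isAPIResponse reports whether a raw "status" field marks the payload as one
// to skip: either the string "ok" or "failed", or a status object whose
// "online" or "good" flag is true.
func isAPIResponse(raw json.RawMessage) bool {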
if len(raw) == 0 {
return false
}
var s string
if json.Unmarshal(raw, &s) == nil {
return s == "ok" || s == "failed"
}
var bs BotStatus
if json.Unmarshal(raw, &bs) == nil {
return bs.Online || bs.Good
}
return false
}
type oneBotSender struct {
UserID json.RawMessage `json:"user_id"`
Nickname string `json:"nickname"`
Card string `json:"card"`
}
type oneBotAPIRequest struct {
Action string `json:"action"`
Params any `json:"params"`
Echo string `json:"echo,omitempty"`
}
type oneBotMessageSegment struct {
Type string `json:"type"`
Data map[string]any `json:"data"`
}
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom,
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
const dedupSize = 1024
return &OneBotChannel{
BaseChannel: base,
config: cfg,
dedup: make(map[string]struct{}, dedupSize),
dedupRing: make([]string, dedupSize),
dedupIdx: 0,
pending: make(map[string]chan json.RawMessage),
}, nil
}
func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) {
go func() {
_, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{
"message_id": messageID,
"emoji_id": emojiID,
"set": set,
}, 5*time.Second)
if err != nil {
logger.DebugCF("onebot", "Failed to set emoji like", map[string]any{
"message_id": messageID,
"error": err.Error(),
})
}
}()
}
// ReactToMessage implements channels.ReactionCapable.
// It adds an emoji reaction (ID 289) to group messages and returns an undo function.
// Private messages return a no-op since reactions are only meaningful in groups.
func (c *OneBotChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
// Only react in group chats
if !strings.HasPrefix(chatID, "group:") {
return func() {}, nil
}
c.setMsgEmojiLike(messageID, 289, true)
return func() {
c.setMsgEmojiLike(messageID, 289, false)
}, nil
}
func (c *OneBotChannel) Start(ctx context.Context) error {
if c.config.WSUrl == "" {
return fmt.Errorf("OneBot ws_url not configured")
}
logger.InfoCF("onebot", "Starting OneBot channel", map[string]any{
"ws_url": c.config.WSUrl,
})
c.ctx, c.cancel = context.WithCancel(ctx)
if err := c.connect(); err != nil {
logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]any{
"error": err.Error(),
})
} else {
go c.listen()
c.fetchSelfID()
}
if c.config.ReconnectInterval > 0 {
go c.reconnectLoop()
} else {
if c.conn == nil {
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
}
}
c.SetRunning(true)
logger.InfoC("onebot", "OneBot channel started successfully")
return nil
}
func (c *OneBotChannel) connect() error {
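// Copy the default dialer so the handshake timeout is not written to the
// package-level websocket.DefaultDialer shared with other callers.
dialer := *websocket.DefaultDialer
dialer.HandshakeTimeout = 10 * time.Second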
header := make(map[string][]string)
if c.config.AccessToken != "" {
header["Authorization"] = []string{"Bearer " + c.config.AccessToken}
}
conn, resp, err := dialer.Dial(c.config.WSUrl, header)
if resp != nil {
resp.Body.Close()
}
if err != nil {
return err
}
conn.SetPongHandler(func(appData string) error {
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.mu.Lock()
c.conn = conn
c.mu.Unlock()
go c.pinger(conn)
logger.InfoC("onebot", "WebSocket connected")
return nil
}
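// pinger sends a WebSocket ping every 30 seconds until the channel context is
// canceled or a write fails. The pong handler installed in connect extends the
// read deadline on each reply.
func (c *OneBotChannel) pinger(conn *websocket.Conn) {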
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.writeMu.Lock()
err := conn.WriteMessage(websocket.PingMessage, nil)
c.writeMu.Unlock()
if err != nil {
logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]any{
"error": err.Error(),
})
return
}
}
}
}
func (c *OneBotChannel) fetchSelfID() {
resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second)
if err != nil {
logger.WarnCF("onebot", "Failed to get_login_info", map[string]any{
"error": err.Error(),
})
return
}
type loginInfo struct {
UserID json.RawMessage `json:"user_id"`
Nickname string `json:"nickname"`
}
for _, extract := range []func() (*loginInfo, error){
func() (*loginInfo, error) {
var w struct {
Data loginInfo `json:"data"`
}
err := json.Unmarshal(resp, &w)
return &w.Data, err
},
func() (*loginInfo, error) {
var f loginInfo
err := json.Unmarshal(resp, &f)
return &f, err
},
} {
info, err := extract()
if err != nil || len(info.UserID) == 0 {
continue
}
if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 {
atomic.StoreInt64(&c.selfID, uid)
logger.InfoCF("onebot", "Bot self ID retrieved", map[string]any{
"self_id": uid,
"nickname": info.Nickname,
})
return
}
}
logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]any{
"response": string(resp),
})
}
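// sendAPIRequest sends an action over the WebSocket and waits for the response
// carrying the same echo value. It returns an error on timeout, when the
// channel is stopped, or when the channel context is canceled.
func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.Duration) (json.RawMessage, error) {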
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
return nil, fmt.Errorf("WebSocket not connected")
}
echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1))
ch := make(chan json.RawMessage, 1)
c.pendingMu.Lock()
c.pending[echo] = ch
c.pendingMu.Unlock()
defer func() {
c.pendingMu.Lock()
delete(c.pending, echo)
c.pendingMu.Unlock()
}()
req := oneBotAPIRequest{
Action: action,
Params: params,
Echo: echo,
}
data, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal API request: %w", err)
}
c.writeMu.Lock()
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err = conn.WriteMessage(websocket.TextMessage, data)
_ = conn.SetWriteDeadline(time.Time{})
c.writeMu.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to write API request: %w", err)
}
select {
case resp := <-ch:
if resp == nil {
return nil, fmt.Errorf("API request %s: channel stopped", action)
}
return resp, nil
case <-time.After(timeout):
return nil, fmt.Errorf("API request %s timed out after %v", action, timeout)
case <-c.ctx.Done():
return nil, fmt.Errorf("context canceled")
}
}
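// reconnectLoop wakes up every ReconnectInterval seconds (at least 5) and
// reconnects whenever the current connection is nil.
func (c *OneBotChannel) reconnectLoop() {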
interval := max(time.Duration(c.config.ReconnectInterval)*time.Second, 5*time.Second)
for {
select {
case <-c.ctx.Done():
return
case <-time.After(interval):
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
logger.InfoC("onebot", "Attempting to reconnect...")
if err := c.connect(); err != nil {
logger.ErrorCF("onebot", "Reconnect failed", map[string]any{
"error": err.Error(),
})
} else {
go c.listen()
c.fetchSelfID()
}
}
}
}
}
func (c *OneBotChannel) Stop(ctx context.Context) error {
logger.InfoC("onebot", "Stopping OneBot channel")
c.SetRunning(false)
if c.cancel != nil {
c.cancel()
}
c.pendingMu.Lock()
for echo, ch := range c.pending {
select {
case ch <- nil: // non-blocking wake for blocked sendAPIRequest goroutines
default:
}
delete(c.pending, echo)
}
c.pendingMu.Unlock()
c.mu.Lock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.mu.Unlock()
return nil
}
func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
// Check ctx before entering write path
select {
case <-ctx.Done():
return ctx.Err()
default:
}
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
return fmt.Errorf("OneBot WebSocket not connected")
}
action, params, err := c.buildSendRequest(msg)
if err != nil {
return err
}
echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1))
req := oneBotAPIRequest{
Action: action,
Params: params,
Echo: echo,
}
data, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal OneBot request: %w", err)
}
c.writeMu.Lock()
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err = conn.WriteMessage(websocket.TextMessage, data)
_ = conn.SetWriteDeadline(time.Time{})
c.writeMu.Unlock()
if err != nil {
logger.ErrorCF("onebot", "Failed to send message", map[string]any{
"error": err.Error(),
})
return fmt.Errorf("onebot send: %w", channels.ErrTemporary)
}
return nil
}
// SendMedia implements the channels.MediaSender interface.
func (c *OneBotChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
return fmt.Errorf("OneBot WebSocket not connected")
}
store := c.GetMediaStore()
if store == nil {
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
}
// Build media segments
var segments []oneBotMessageSegment
for _, part := range msg.Parts {
localPath, err := store.Resolve(part.Ref)
if err != nil {
logger.ErrorCF("onebot", "Failed to resolve media ref", map[string]any{
"ref": part.Ref,
"error": err.Error(),
})
continue
}
var segType string
switch part.Type {
case "image":
segType = "image"
case "video":
segType = "video"
case "audio":
segType = "record"
default:
segType = "file"
}
segments = append(segments, oneBotMessageSegment{
Type: segType,
Data: map[string]any{"file": "file://" + localPath},
})
if part.Caption != "" {
segments = append(segments, oneBotMessageSegment{
Type: "text",
Data: map[string]any{"text": part.Caption},
})
}
}
if len(segments) == 0 {
return nil
}
chatID := msg.ChatID
var action, idKey string
var rawID string
if rest, ok := strings.CutPrefix(chatID, "group:"); ok {
action, idKey, rawID = "send_group_msg", "group_id", rest
} else if rest, ok := strings.CutPrefix(chatID, "private:"); ok {
action, idKey, rawID = "send_private_msg", "user_id", rest
} else {
action, idKey, rawID = "send_private_msg", "user_id", chatID
}
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil {
return fmt.Errorf("invalid %s in chatID: %s: %w", idKey, chatID, channels.ErrSendFailed)
}
echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1))
req := oneBotAPIRequest{
Action: action,
Params: map[string]any{idKey: id, "message": segments},
Echo: echo,
}
data, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal OneBot request: %w", err)
}
c.writeMu.Lock()
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err = conn.WriteMessage(websocket.TextMessage, data)
_ = conn.SetWriteDeadline(time.Time{})
c.writeMu.Unlock()
if err != nil {
logger.ErrorCF("onebot", "Failed to send media message", map[string]any{
"error": err.Error(),
})
return fmt.Errorf("onebot send media: %w", channels.ErrTemporary)
}
return nil
}
func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment {
var segments []oneBotMessageSegment
if lastMsgID, ok := c.lastMessageID.Load(chatID); ok {
if msgID, ok := lastMsgID.(string); ok && msgID != "" {
segments = append(segments, oneBotMessageSegment{
Type: "reply",
Data: map[string]any{"id": msgID},
})
}
}
segments = append(segments, oneBotMessageSegment{
Type: "text",
Data: map[string]any{"text": content},
})
return segments
}
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, any, error) {
chatID := msg.ChatID
segments := c.buildMessageSegments(chatID, msg.Content)
var action, idKey string
var rawID string
if rest, ok := strings.CutPrefix(chatID, "group:"); ok {
action, idKey, rawID = "send_group_msg", "group_id", rest
} else if rest, ok := strings.CutPrefix(chatID, "private:"); ok {
action, idKey, rawID = "send_private_msg", "user_id", rest
} else {
action, idKey, rawID = "send_private_msg", "user_id", chatID
}
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil {
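// An unparsable chat ID is permanent; wrap ErrSendFailed (as SendMedia does) so it is not retried.
return "", nil, fmt.Errorf("invalid %s in chatID: %s: %w", idKey, chatID, channels.ErrSendFailed)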
}
return action, map[string]any{idKey: id, "message": segments}, nil
}
func (c *OneBotChannel) listen() {
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
return
}
for {
select {
case <-c.ctx.Done():
return
default:
_, message, err := conn.ReadMessage()
if err != nil {
logger.ErrorCF("onebot", "WebSocket read error", map[string]any{
"error": err.Error(),
})
c.mu.Lock()
if c.conn == conn {
c.conn.Close()
c.conn = nil
}
c.mu.Unlock()
return
}
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
var raw oneBotRawEvent
if err := json.Unmarshal(message, &raw); err != nil {
logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]any{
"error": err.Error(),
"payload": string(message),
})
continue
}
logger.DebugCF("onebot", "WebSocket event", map[string]any{
"length": len(message),
"post_type": raw.PostType,
"sub_type": raw.SubType,
})
if raw.Echo != "" {
c.pendingMu.Lock()
ch, ok := c.pending[raw.Echo]
c.pendingMu.Unlock()
if ok {
select {
case ch <- message:
default:
}
} else {
logger.DebugCF("onebot", "Received API response (no waiter)", map[string]any{
"echo": raw.Echo,
"status": string(raw.Status),
})
}
continue
}
if isAPIResponse(raw.Status) {
logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]any{
"status": string(raw.Status),
})
continue
}
c.handleRawEvent(&raw)
}
}
}
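// parseJSONInt64 decodes a JSON number or a numeric JSON string into an int64.
// Empty input yields 0 with no error.
func parseJSONInt64(raw json.RawMessage) (int64, error) {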
if len(raw) == 0 {
return 0, nil
}
var n int64
if err := json.Unmarshal(raw, &n); err == nil {
return n, nil
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return strconv.ParseInt(s, 10, 64)
}
return 0, fmt.Errorf("cannot parse as int64: %s", string(raw))
}
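// parseJSONString decodes a JSON string. Non-string values are returned as
// their raw JSON text, so numeric IDs come back as decimal strings.
func parseJSONString(raw json.RawMessage) string {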
if len(raw) == 0 {
return ""
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s
}
return string(raw)
}
type parseMessageResult struct {
Text string
IsBotMentioned bool
Media []string
ReplyTo string
}
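// parseMessageSegments accepts either a CQ-code string or an array of message
// segments. It collects the text, records whether the bot was mentioned,
// downloads attached media into the store, and captures the replied-to ID.
func (c *OneBotChannel) parseMessageSegments(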
raw json.RawMessage,
selfID int64,
store media.MediaStore,
scope string,
) parseMessageResult {
if len(raw) == 0 {
return parseMessageResult{}
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
mentioned := false
if selfID > 0 {
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
if strings.Contains(s, cqAt) {
mentioned = true
s = strings.ReplaceAll(s, cqAt, "")
s = strings.TrimSpace(s)
}
}
return parseMessageResult{Text: s, IsBotMentioned: mentioned}
}
var segments []map[string]any
if err := json.Unmarshal(raw, &segments); err != nil {
return parseMessageResult{}
}
var textParts []string
mentioned := false
selfIDStr := strconv.FormatInt(selfID, 10)
var mediaRefs []string
var replyTo string
// Helper to register a local file with the media store
storeFile := func(localPath, filename string) string {
if store != nil {
ref, err := store.Store(localPath, media.MediaMeta{
Filename: filename,
Source: "onebot",
}, scope)
if err == nil {
return ref
}
}
return localPath // fallback
}
for _, seg := range segments {
segType, _ := seg["type"].(string)
data, _ := seg["data"].(map[string]any)
switch segType {
case "text":
if data != nil {
if t, ok := data["text"].(string); ok {
textParts = append(textParts, t)
}
}
case "at":
if data != nil && selfID > 0 {
qqVal := fmt.Sprintf("%v", data["qq"])
if qqVal == selfIDStr || qqVal == "all" {
mentioned = true
}
}
case "image", "video", "file":
if data != nil {
url, _ := data["url"].(string)
if url != "" {
defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"}
filename := defaults[segType]
if f, ok := data["file"].(string); ok && f != "" {
filename = f
} else if n, ok := data["name"].(string); ok && n != "" {
filename = n
}
localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "onebot",
})
if localPath != "" {
mediaRefs = append(mediaRefs, storeFile(localPath, filename))
textParts = append(textParts, fmt.Sprintf("[%s]", segType))
}
}
}
case "record":
if data != nil {
url, _ := data["url"].(string)
if url != "" {
localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{
LoggerPrefix: "onebot",
})
if localPath != "" {
textParts = append(textParts, "[voice]")
mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr"))
}
}
}
case "reply":
if data != nil {
if id, ok := data["id"]; ok {
replyTo = fmt.Sprintf("%v", id)
}
}
case "face":
if data != nil {
faceID, _ := data["id"]
textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID))
}
case "forward":
textParts = append(textParts, "[forward message]")
default:
}
}
return parseMessageResult{
Text: strings.TrimSpace(strings.Join(textParts, "")),
IsBotMentioned: mentioned,
Media: mediaRefs,
ReplyTo: replyTo,
}
}
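// handleRawEvent dispatches an event by post_type. Message events get an
// early allowlist check so rejected senders never trigger media downloads.
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {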
switch raw.PostType {
case "message":
if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 {
// Build minimal sender for allowlist check
sender := bus.SenderInfo{
Platform: "onebot",
PlatformID: strconv.FormatInt(userID, 10),
CanonicalID: identity.BuildCanonicalID("onebot", strconv.FormatInt(userID, 10)),
}
if !c.IsAllowedSender(sender) {
logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{
"user_id": userID,
})
return
}
}
c.handleMessage(raw)
case "message_sent":
logger.DebugCF("onebot", "Bot sent message event", map[string]any{
"message_type": raw.MessageType,
"message_id": parseJSONString(raw.MessageID),
})
case "meta_event":
c.handleMetaEvent(raw)
case "notice":
c.handleNoticeEvent(raw)
case "request":
logger.DebugCF("onebot", "Request event received", map[string]any{
"sub_type": raw.SubType,
})
case "":
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]any{
"echo": raw.Echo,
"status": raw.Status,
})
default:
logger.DebugCF("onebot", "Unknown post_type", map[string]any{
"post_type": raw.PostType,
})
}
}
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
if raw.MetaEventType == "lifecycle" {
logger.InfoCF("onebot", "Lifecycle event", map[string]any{"sub_type": raw.SubType})
} else if raw.MetaEventType != "heartbeat" {
logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil)
}
}
func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) {
fields := map[string]any{
"notice_type": raw.NoticeType,
"sub_type": raw.SubType,
"group_id": parseJSONString(raw.GroupID),
"user_id": parseJSONString(raw.UserID),
"message_id": parseJSONString(raw.MessageID),
}
switch raw.NoticeType {
case "group_recall", "group_increase", "group_decrease",
"friend_add", "group_admin", "group_ban":
logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields)
default:
logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields)
}
}
func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
// Parse fields from raw event
userID, err := parseJSONInt64(raw.UserID)
if err != nil {
logger.WarnCF("onebot", "Failed to parse user_id", map[string]any{
"error": err.Error(),
"raw": string(raw.UserID),
})
return
}
groupID, _ := parseJSONInt64(raw.GroupID)
selfID, _ := parseJSONInt64(raw.SelfID)
messageID := parseJSONString(raw.MessageID)
if selfID == 0 {
selfID = atomic.LoadInt64(&c.selfID)
}
// Compute scope for media store before parsing (parsing may download files)
var chatIDForScope string
switch raw.MessageType {
case "group":
chatIDForScope = "group:" + strconv.FormatInt(groupID, 10)
default:
chatIDForScope = "private:" + strconv.FormatInt(userID, 10)
}
scope := channels.BuildMediaScope("onebot", chatIDForScope, messageID)
parsed := c.parseMessageSegments(raw.Message, selfID, c.GetMediaStore(), scope)
isBotMentioned := parsed.IsBotMentioned
content := raw.RawMessage
if content == "" {
content = parsed.Text
} else if selfID > 0 {
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
if strings.Contains(content, cqAt) {
isBotMentioned = true
content = strings.ReplaceAll(content, cqAt, "")
content = strings.TrimSpace(content)
}
}
if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") {
content = parsed.Text
}
var sender oneBotSender
if len(raw.Sender) > 0 {
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
logger.WarnCF("onebot", "Failed to parse sender", map[string]any{
"error": err.Error(),
"sender": string(raw.Sender),
})
}
}
if c.isDuplicate(messageID) {
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{
"message_id": messageID,
})
return
}
if content == "" {
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]any{
"message_id": messageID,
})
return
}
senderID := strconv.FormatInt(userID, 10)
var chatID string
var peer bus.Peer
metadata := map[string]string{}
if parsed.ReplyTo != "" {
metadata["reply_to_message_id"] = parsed.ReplyTo
}
switch raw.MessageType {
case "private":
chatID = "private:" + senderID
peer = bus.Peer{Kind: "direct", ID: senderID}
case "group":
groupIDStr := strconv.FormatInt(groupID, 10)
chatID = "group:" + groupIDStr
peer = bus.Peer{Kind: "group", ID: groupIDStr}
metadata["group_id"] = groupIDStr
senderUserID, _ := parseJSONInt64(sender.UserID)
if senderUserID > 0 {
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
}
if sender.Card != "" {
metadata["sender_name"] = sender.Card
} else if sender.Nickname != "" {
metadata["sender_name"] = sender.Nickname
}
respond, strippedContent := c.ShouldRespondInGroup(isBotMentioned, content)
if !respond {
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{
"sender": senderID,
"group": groupIDStr,
"is_mentioned": isBotMentioned,
"content": truncate(content, 100),
})
return
}
content = strippedContent
default:
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]any{
"type": raw.MessageType,
"message_id": messageID,
"user_id": userID,
})
return
}
logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]any{
"sender": senderID,
"chat_id": chatID,
"message_id": messageID,
"length": len(content),
"content": truncate(content, 100),
"media_count": len(parsed.Media),
})
if sender.Nickname != "" {
metadata["nickname"] = sender.Nickname
}
c.lastMessageID.Store(chatID, messageID)
senderInfo := bus.SenderInfo{
Platform: "onebot",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("onebot", senderID),
DisplayName: sender.Nickname,
}
if !c.IsAllowedSender(senderInfo) {
logger.DebugCF("onebot", "Message rejected by allowlist (senderInfo)", map[string]any{
"sender": senderID,
})
return
}
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo)
}
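// isDuplicate reports whether messageID was seen recently. Seen IDs live in a
// set bounded by a ring buffer: once the ring is full, the oldest ID is
// evicted. Empty and "0" IDs are never treated as duplicates.
func (c *OneBotChannel) isDuplicate(messageID string) bool {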
if messageID == "" || messageID == "0" {
return false
}
c.mu.Lock()
defer c.mu.Unlock()
if _, exists := c.dedup[messageID]; exists {
return true
}
if old := c.dedupRing[c.dedupIdx]; old != "" {
delete(c.dedup, old)
}
c.dedupRing[c.dedupIdx] = messageID
c.dedup[messageID] = struct{}{}
c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing)
return false
}
// truncate shortens s to at most n runes (not bytes), appending "..." when
// anything was cut.
func truncate(s string, n int) string {
runes := []rune(s)
if len(runes) <= n {
return s
}
return string(runes[:n]) + "..."
}
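
The `isDuplicate` method above pairs a set with a fixed-size ring so that memory stays bounded. The following is a minimal standalone sketch of that idea, assuming a single goroutine (the original guards the state with `c.mu`) and a ring size greater than zero. The type `dedupRing` and the constructor `newDedupRing` are hypothetical names introduced for illustration only.

```go
package main

import "fmt"

// dedupRing remembers the most recent IDs and reports repeats.
// Hypothetical type for illustration; not part of the original source.
type dedupRing struct {
	seen map[string]struct{} // fast membership test
	ring []string            // insertion order, used for eviction
	idx  int                 // next slot to overwrite
}

// newDedupRing creates a ring holding up to size IDs (size must be > 0).
func newDedupRing(size int) *dedupRing {
	return &dedupRing{
		seen: make(map[string]struct{}, size),
		ring: make([]string, size),
	}
}

// isDuplicate returns true if id is still remembered; otherwise it records
// id, evicting the oldest remembered ID when the ring is full.
func (d *dedupRing) isDuplicate(id string) bool {
	if id == "" || id == "0" {
		return false // IDs that carry no identity are never deduplicated
	}
	if _, ok := d.seen[id]; ok {
		return true
	}
	if old := d.ring[d.idx]; old != "" {
		delete(d.seen, old)
	}
	d.ring[d.idx] = id
	d.seen[id] = struct{}{}
	d.idx = (d.idx + 1) % len(d.ring)
	return false
}

func main() {
	d := newDedupRing(2)
	for _, id := range []string{"a", "a", "b", "c", "a"} {
		fmt.Println(id, d.isDuplicate(id))
	}
}
```

With a ring of size 2, the second "a" is reported as a duplicate, while the final "a" is not, because "c" has already evicted it.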
================================================
FILE: pkg/channels/pico/init.go
================================================
package pico
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func init() {
channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewPicoChannel(cfg.Channels.Pico, b)
})
}
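
The `init` function above registers a constructor under the name "pico". How `channels.RegisterFactory` stores and looks up factories is not shown in this excerpt, so the sketch below is only one plausible shape for such a registry. `Channel`, `Factory`, `RegisterFactory`, `Build`, and `echoChannel` here are simplified, hypothetical stand-ins, not the project's real types or signatures.

```go
package main

import "fmt"

// Channel is a hypothetical, minimal stand-in for a channel interface.
type Channel interface{ Name() string }

// Factory builds a Channel. The real factory takes config and a message bus.
type Factory func() (Channel, error)

// factories maps a channel name to its constructor.
var factories = map[string]Factory{}

// RegisterFactory records a constructor, typically from a package init().
func RegisterFactory(name string, f Factory) { factories[name] = f }

// Build looks up a constructor by name and invokes it.
func Build(name string) (Channel, error) {
	f, ok := factories[name]
	if !ok {
		return nil, fmt.Errorf("unknown channel %q", name)
	}
	return f()
}

// echoChannel is a toy implementation used only by this sketch.
type echoChannel struct{}

func (echoChannel) Name() string { return "echo" }

func init() {
	RegisterFactory("echo", func() (Channel, error) { return echoChannel{}, nil })
}

func main() {
	ch, err := Build("echo")
	if err != nil {
		fmt.Println("build failed:", err)
		return
	}
	fmt.Println("built channel:", ch.Name())
}
```

Registering from `init` means a channel becomes available simply by importing its package, which keeps the core free of direct dependencies on each channel implementation.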
================================================
FILE: pkg/channels/pico/pico.go
================================================
package pico
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
// picoConn represents a single WebSocket connection.
type picoConn struct {
id string
conn *websocket.Conn
sessionID string
writeMu sync.Mutex
closed atomic.Bool
}
// writeJSON sends a JSON message to the connection with write locking.
func (pc *picoConn) writeJSON(v any) error {
if pc.closed.Load() {
return fmt.Errorf("connection closed")
}
pc.writeMu.Lock()
defer pc.writeMu.Unlock()
return pc.conn.WriteJSON(v)
}
// close closes the connection.
func (pc *picoConn) close() {
if pc.closed.CompareAndSwap(false, true) {
pc.conn.Close()
}
}
// PicoChannel implements the native Pico Protocol WebSocket channel.
// It serves as the reference implementation for all optional capability interfaces.
type PicoChannel struct {
*channels.BaseChannel
config config.PicoConfig
upgrader websocket.Upgrader
connections sync.Map // connID → *picoConn
connCount atomic.Int32
ctx context.Context
cancel context.CancelFunc
}
// NewPicoChannel creates a new Pico Protocol channel.
func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) {
if cfg.Token == "" {
return nil, fmt.Errorf("pico token is required")
}
base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom)
allowOrigins := cfg.AllowOrigins
checkOrigin := func(r *http.Request) bool {
if len(allowOrigins) == 0 {
return true // allow all if not configured
}
origin := r.Header.Get("Origin")
for _, allowed := range allowOrigins {
if allowed == "*" || allowed == origin {
return true
}
}
return false
}
return &PicoChannel{
BaseChannel: base,
config: cfg,
upgrader: websocket.Upgrader{
CheckOrigin: checkOrigin,
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
}, nil
}
// Start implements Channel.
func (c *PicoChannel) Start(ctx context.Context) error {
logger.InfoC("pico", "Starting Pico Protocol channel")
c.ctx, c.cancel = context.WithCancel(ctx)
c.SetRunning(true)
logger.InfoC("pico", "Pico Protocol channel started")
return nil
}
// Stop implements Channel.
func (c *PicoChannel) Stop(ctx context.Context) error {
logger.InfoC("pico", "Stopping Pico Protocol channel")
c.SetRunning(false)
// Close all connections
c.connections.Range(func(key, value any) bool {
if pc, ok := value.(*picoConn); ok {
pc.close()
}
c.connections.Delete(key)
return true
})
if c.cancel != nil {
c.cancel()
}
logger.InfoC("pico", "Pico Protocol channel stopped")
return nil
}
// WebhookPath implements channels.WebhookHandler.
func (c *PicoChannel) WebhookPath() string { return "/pico/" }
// ServeHTTP implements http.Handler for the shared HTTP server.
func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := strings.TrimPrefix(r.URL.Path, "/pico")
switch {
case path == "/ws" || path == "/ws/":
c.handleWebSocket(w, r)
default:
http.NotFound(w, r)
}
}
// Send implements Channel — sends a message to the appropriate WebSocket connection.
func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
outMsg := newMessage(TypeMessageCreate, map[string]any{
"content": msg.Content,
})
return c.broadcastToSession(msg.ChatID, outMsg)
}
// EditMessage implements channels.MessageEditor.
func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
outMsg := newMessage(TypeMessageUpdate, map[string]any{
"message_id": messageID,
"content": content,
})
return c.broadcastToSession(chatID, outMsg)
}
// StartTyping implements channels.TypingCapable.
func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
startMsg := newMessage(TypeTypingStart, nil)
if err := c.broadcastToSession(chatID, startMsg); err != nil {
return func() {}, err
}
return func() {
stopMsg := newMessage(TypeTypingStop, nil)
c.broadcastToSession(chatID, stopMsg)
}, nil
}
// SendPlaceholder implements channels.PlaceholderCapable.
// It sends a placeholder message via the Pico Protocol that will later be
// edited to the actual response via EditMessage (channels.MessageEditor).
func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
if !c.config.Placeholder.Enabled {
return "", nil
}
text := c.config.Placeholder.Text
if text == "" {
text = "Thinking... 💭"
}
msgID := uuid.New().String()
outMsg := newMessage(TypeMessageCreate, map[string]any{
"content": text,
"message_id": msgID,
})
if err := c.broadcastToSession(chatID, outMsg); err != nil {
return "", err
}
return msgID, nil
}
// broadcastToSession sends a message to all connections with a matching session.
func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error {
// chatID format: "pico:%s", escaped))
}
for i, code := range codeBlocks.codes {
escaped := escapeHTML(code)
text = strings.ReplaceAll(
text,
fmt.Sprintf("\x00CB%d\x00", i),
fmt.Sprintf("<pre><code>%s</code></pre>", escaped),
)
}
return text
}
type codeBlockMatch struct {
text string
codes []string
}
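// extractCodeBlocks replaces each fenced code block with a "\x00CB<i>\x00"
// placeholder and returns the captured bodies in order, so later formatting
// passes cannot touch code content.
func extractCodeBlocks(text string) codeBlockMatch {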
matches := reCodeBlock.FindAllStringSubmatch(text, -1)
codes := make([]string, 0, len(matches))
for _, match := range matches {
codes = append(codes, match[1])
}
i := 0
text = reCodeBlock.ReplaceAllStringFunc(text, func(m string) string {
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
i++
return placeholder
})
return codeBlockMatch{text: text, codes: codes}
}
type inlineCodeMatch struct {
text string
codes []string
}
// extractInlineCodes does the same for inline code spans, using
// "\x00IC<i>\x00" placeholders.
func extractInlineCodes(text string) inlineCodeMatch {
matches := reInlineCode.FindAllStringSubmatch(text, -1)
codes := make([]string, 0, len(matches))
for _, match := range matches {
codes = append(codes, match[1])
}
i := 0
text = reInlineCode.ReplaceAllStringFunc(text, func(m string) string {
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
i++
return placeholder
})
return inlineCodeMatch{text: text, codes: codes}
}
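// escapeHTML escapes the three characters that are special in Telegram's HTML
// parse mode. "&" must be replaced first so the entities added by the later
// replacements are not escaped a second time.
func escapeHTML(text string) string {
text = strings.ReplaceAll(text, "&", "&amp;")
text = strings.ReplaceAll(text, "<", "&lt;")
text = strings.ReplaceAll(text, ">", "&gt;")
return text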
}
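
The helpers above follow an extract, escape, restore pattern: code is swapped for placeholders, the remaining text is escaped and formatted, and the code is put back last. The sketch below shows this for inline code only. The function name `renderInline` and the `<code>` wrapper are assumptions made for illustration; the placeholder format mirrors the `\x00IC%d\x00` form used above.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// reInline matches a backtick-delimited inline code span.
var reInline = regexp.MustCompile("`([^`]+)`")

// escapeHTML escapes &, <, and > (in that order, so entities are not
// double-escaped).
func escapeHTML(s string) string {
	s = strings.ReplaceAll(s, "&", "&amp;")
	s = strings.ReplaceAll(s, "<", "&lt;")
	s = strings.ReplaceAll(s, ">", "&gt;")
	return s
}

// renderInline is a hypothetical helper: it protects inline code with
// placeholders, escapes the surrounding text, then restores the code
// wrapped in <code> tags.
func renderInline(text string) string {
	var codes []string
	// Step 1: extract each code span and leave a numbered placeholder.
	text = reInline.ReplaceAllStringFunc(text, func(m string) string {
		sub := reInline.FindStringSubmatch(m)
		codes = append(codes, sub[1])
		return fmt.Sprintf("\x00IC%d\x00", len(codes)-1)
	})
	// Step 2: escape everything that is not code.
	text = escapeHTML(text)
	// Step 3: restore the code, escaped separately.
	for i, code := range codes {
		placeholder := fmt.Sprintf("\x00IC%d\x00", i)
		text = strings.ReplaceAll(text, placeholder, "<code>"+escapeHTML(code)+"</code>")
	}
	return text
}

func main() {
	fmt.Println(renderInline("a < b and `x<y`"))
}
```

Because the placeholder contains NUL bytes, it is unlikely to collide with user text, and it passes through `escapeHTML` unchanged.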
================================================
FILE: pkg/channels/telegram/telegram.go
================================================
package telegram
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
"regexp"
"strconv"
"strings"
"time"
"github.com/mymmrac/telego"
th "github.com/mymmrac/telego/telegohandler"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
)
var (
reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+([^\n]+)`)
reBlockquote = regexp.MustCompile(`^>\s*(.*)$`)
reLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
reBoldStar = regexp.MustCompile(`\*\*(.+?)\*\*`)
reBoldUnder = regexp.MustCompile(`__(.+?)__`)
reItalic = regexp.MustCompile(`_([^_]+)_`)
reStrike = regexp.MustCompile(`~~(.+?)~~`)
reListItem = regexp.MustCompile(`^[-*]\s+`)
reCodeBlock = regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```")
reInlineCode = regexp.MustCompile("`([^`]+)`")
)
type TelegramChannel struct {
*channels.BaseChannel
bot *telego.Bot
bh *th.BotHandler
config *config.Config
chatIDs map[string]int64
ctx context.Context
cancel context.CancelFunc
registerFunc func(context.Context, []commands.Definition) error
commandRegCancel context.CancelFunc
}
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
var opts []telego.BotOption
telegramCfg := cfg.Channels.Telegram
if telegramCfg.Proxy != "" {
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
if parseErr != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
}
opts = append(opts, telego.WithHTTPClient(&http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}))
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
// Use environment proxy if configured
opts = append(opts, telego.WithHTTPClient(&http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}))
}
if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" {
opts = append(opts, telego.WithAPIServer(baseURL))
}
opts = append(opts, telego.WithLogger(logger.NewLogger("telego")))
bot, err := telego.NewBot(telegramCfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
}
base := channels.NewBaseChannel(
"telegram",
telegramCfg,
bus,
telegramCfg.AllowFrom,
channels.WithMaxMessageLength(4000),
channels.WithGroupTrigger(telegramCfg.GroupTrigger),
channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID),
)
return &TelegramChannel{
BaseChannel: base,
bot: bot,
config: cfg,
chatIDs: make(map[string]int64),
}, nil
}
func (c *TelegramChannel) Start(ctx context.Context) error {
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
if err != nil {
c.cancel()
return fmt.Errorf("failed to start long polling: %w", err)
}
bh, err := th.NewBotHandler(c.bot, updates)
if err != nil {
c.cancel()
return fmt.Errorf("failed to create bot handler: %w", err)
}
c.bh = bh
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.handleMessage(ctx, &message)
}, th.AnyMessage())
c.SetRunning(true)
logger.InfoCF("telegram", "Telegram bot connected", map[string]any{
"username": c.bot.Username(),
})
c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions())
go func() {
if err := bh.Start(); err != nil {
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
"error": err.Error(),
})
}
}()
return nil
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
logger.InfoC("telegram", "Stopping Telegram bot...")
c.SetRunning(false)
// Stop the bot handler
if c.bh != nil {
_ = c.bh.StopWithContext(ctx)
}
// Cancel our context (stops long polling)
if c.cancel != nil {
c.cancel()
}
if c.commandRegCancel != nil {
c.commandRegCancel()
}
return nil
}
func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
if err != nil {
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
if msg.Content == "" {
return nil
}
// The Manager already splits messages to ≤4000 chars (WithMaxMessageLength),
// so msg.Content is guaranteed to be within that limit. We still need to
// check if HTML expansion pushes it beyond Telegram's 4096-char API limit.
replyToID := msg.ReplyToMessageID
queue := []string{msg.Content}
for len(queue) > 0 {
chunk := queue[0]
queue = queue[1:]
content := parseContent(chunk, useMarkdownV2)
if len([]rune(content)) > 4096 {
runeChunk := []rune(chunk)
ratio := float64(len(runeChunk)) / float64(len([]rune(content)))
smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin
// Guarantee progress: if estimated length is >= chunk length, force it smaller
if smallerLen >= len(runeChunk) {
smallerLen = len(runeChunk) - 1
}
if smallerLen <= 0 {
if err := c.sendChunk(ctx, sendChunkParams{
chatID: chatID,
threadID: threadID,
content: content,
replyToID: replyToID,
mdFallback: chunk,
useMarkdownV2: useMarkdownV2,
}); err != nil {
return err
}
replyToID = ""
continue
}
// Use the estimated smaller length as a guide for SplitMessage.
// SplitMessage will find natural break points (newlines/spaces) and respect code blocks.
subChunks := channels.SplitMessage(chunk, smallerLen)
// Safety fallback: If SplitMessage failed to shorten the chunk, force a manual hard split.
if len(subChunks) == 1 && subChunks[0] == chunk {
part1 := string(runeChunk[:smallerLen])
part2 := string(runeChunk[smallerLen:])
subChunks = []string{part1, part2}
}
// Filter out empty chunks to avoid sending empty messages to Telegram.
nonEmpty := make([]string, 0, len(subChunks))
for _, s := range subChunks {
if s != "" {
nonEmpty = append(nonEmpty, s)
}
}
// Push sub-chunks back to the front of the queue
queue = append(nonEmpty, queue...)
continue
}
if err := c.sendChunk(ctx, sendChunkParams{
chatID: chatID,
threadID: threadID,
content: content,
replyToID: replyToID,
mdFallback: chunk,
useMarkdownV2: useMarkdownV2,
}); err != nil {
return err
}
// Only the first chunk should be a reply; subsequent chunks are normal messages.
replyToID = ""
}
return nil
}
type sendChunkParams struct {
chatID int64
threadID int
content string
replyToID string
mdFallback string
useMarkdownV2 bool
}
// sendChunk sends a single HTML/MarkdownV2 message, falling back to the original
// markdown as plain text on parse failure so users never see raw HTML/MarkdownV2 tags.
func (c *TelegramChannel) sendChunk(
ctx context.Context,
params sendChunkParams,
) error {
tgMsg := tu.Message(tu.ID(params.chatID), params.content)
tgMsg.MessageThreadID = params.threadID
if params.useMarkdownV2 {
tgMsg.WithParseMode(telego.ModeMarkdownV2)
} else {
tgMsg.WithParseMode(telego.ModeHTML)
}
if params.replyToID != "" {
if mid, parseErr := strconv.Atoi(params.replyToID); parseErr == nil {
tgMsg.ReplyParameters = &telego.ReplyParameters{
MessageID: mid,
}
}
}
if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil {
logParseFailed(err, params.useMarkdownV2)
tgMsg.Text = params.mdFallback
tgMsg.ParseMode = ""
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
return fmt.Errorf("telegram send: %w", channels.ErrTemporary)
}
}
return nil
}
// maxTypingDuration limits how long the typing indicator can run.
// Prevents endless typing when the LLM fails/hangs and preSend never invokes cancel.
// Matches channels.Manager's typingStopTTL (5 min) so behavior is consistent.
const maxTypingDuration = 5 * time.Minute
// StartTyping implements channels.TypingCapable.
// It sends ChatAction(typing) immediately and then repeats every 4 seconds
// (Telegram's typing indicator expires after ~5s) in a background goroutine.
// The returned stop function is idempotent and cancels the goroutine.
// The goroutine also exits automatically after maxTypingDuration if cancel is
// never called (e.g. when the LLM fails or times out without publishing).
func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
cid, threadID, err := parseTelegramChatID(chatID)
if err != nil {
return func() {}, err
}
action := tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)
action.MessageThreadID = threadID
// Send the first typing action immediately
_ = c.bot.SendChatAction(ctx, action)
typingCtx, cancel := context.WithCancel(ctx)
// Cap lifetime so the goroutine cannot run indefinitely if cancel is never called
maxCtx, maxCancel := context.WithTimeout(typingCtx, maxTypingDuration)
go func() {
defer maxCancel()
ticker := time.NewTicker(4 * time.Second)
defer ticker.Stop()
for {
select {
case <-maxCtx.Done():
return
case <-ticker.C:
a := tu.ChatAction(tu.ID(cid), telego.ChatActionTyping)
a.MessageThreadID = threadID
_ = c.bot.SendChatAction(typingCtx, a)
}
}
}()
return cancel, nil
}
// EditMessage implements channels.MessageEditor.
func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
cid, _, err := parseTelegramChatID(chatID)
if err != nil {
return err
}
mid, err := strconv.Atoi(messageID)
if err != nil {
return err
}
parsedContent := parseContent(content, useMarkdownV2)
editMsg := tu.EditMessageText(tu.ID(cid), mid, parsedContent)
if useMarkdownV2 {
editMsg.WithParseMode(telego.ModeMarkdownV2)
} else {
editMsg.WithParseMode(telego.ModeHTML)
}
_, err = c.bot.EditMessageText(ctx, editMsg)
if err != nil {
logParseFailed(err, useMarkdownV2)
_, err = c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(cid), mid, content))
}
return err
}
// SendPlaceholder implements channels.PlaceholderCapable.
// It sends a placeholder message (e.g. "Thinking... 💭") that will later be
// edited to the actual response via EditMessage (channels.MessageEditor).
func (c *TelegramChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
phCfg := c.config.Channels.Telegram.Placeholder
if !phCfg.Enabled {
return "", nil
}
text := phCfg.Text
if text == "" {
text = "Thinking... 💭"
}
cid, threadID, err := parseTelegramChatID(chatID)
if err != nil {
return "", err
}
phMsg := tu.Message(tu.ID(cid), text)
phMsg.MessageThreadID = threadID
pMsg, err := c.bot.SendMessage(ctx, phMsg)
if err != nil {
return "", err
}
return fmt.Sprintf("%d", pMsg.MessageID), nil
}
// SendMedia implements the channels.MediaSender interface.
func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
if err != nil {
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
store := c.GetMediaStore()
if store == nil {
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
}
for _, part := range msg.Parts {
localPath, err := store.Resolve(part.Ref)
if err != nil {
logger.ErrorCF("telegram", "Failed to resolve media ref", map[string]any{
"ref": part.Ref,
"error": err.Error(),
})
continue
}
file, err := os.Open(localPath)
if err != nil {
logger.ErrorCF("telegram", "Failed to open media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
continue
}
switch part.Type {
case "image":
params := &telego.SendPhotoParams{
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Photo: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendPhoto(ctx, params)
if err != nil && strings.Contains(err.Error(), "PHOTO_INVALID_DIMENSIONS") {
if _, seekErr := file.Seek(0, io.SeekStart); seekErr != nil {
file.Close()
return fmt.Errorf("telegram rewind media after photo failure: %w", channels.ErrTemporary)
}
docParams := &telego.SendDocumentParams{
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Document: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendDocument(ctx, docParams)
}
case "audio":
params := &telego.SendAudioParams{
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Audio: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendAudio(ctx, params)
case "video":
params := &telego.SendVideoParams{
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Video: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendVideo(ctx, params)
default: // "file" or unknown types
params := &telego.SendDocumentParams{
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Document: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendDocument(ctx, params)
}
file.Close()
if err != nil {
logger.ErrorCF("telegram", "Failed to send media", map[string]any{
"type": part.Type,
"error": err.Error(),
})
return fmt.Errorf("telegram send media: %w", channels.ErrTemporary)
}
}
return nil
}
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
if message == nil {
return fmt.Errorf("message is nil")
}
user := message.From
if user == nil {
return fmt.Errorf("message sender (user) is nil")
}
platformID := fmt.Sprintf("%d", user.ID)
sender := bus.SenderInfo{
Platform: "telegram",
PlatformID: platformID,
CanonicalID: identity.BuildCanonicalID("telegram", platformID),
Username: user.Username,
DisplayName: user.FirstName,
}
// check allowlist to avoid downloading attachments for rejected users
if !c.IsAllowedSender(sender) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{
"user_id": platformID,
})
return nil
}
chatID := message.Chat.ID
c.chatIDs[platformID] = chatID
content := ""
mediaPaths := []string{}
chatIDStr := fmt.Sprintf("%d", chatID)
messageIDStr := fmt.Sprintf("%d", message.MessageID)
scope := channels.BuildMediaScope("telegram", chatIDStr, messageIDStr)
// Helper to register a local file with the media store
storeMedia := func(localPath, filename string) string {
if store := c.GetMediaStore(); store != nil {
ref, err := store.Store(localPath, media.MediaMeta{
Filename: filename,
Source: "telegram",
}, scope)
if err == nil {
return ref
}
}
return localPath // fallback: use raw path
}
if message.Text != "" {
content += message.Text
}
if message.Caption != "" {
if content != "" {
content += "\n"
}
content += message.Caption
}
if len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
mediaPaths = append(mediaPaths, storeMedia(photoPath, "photo.jpg"))
if content != "" {
content += "\n"
}
content += "[image: photo]"
}
}
if message.Voice != nil {
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
if voicePath != "" {
mediaPaths = append(mediaPaths, storeMedia(voicePath, "voice.ogg"))
if content != "" {
content += "\n"
}
content += "[voice]"
}
}
if message.Audio != nil {
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
if audioPath != "" {
mediaPaths = append(mediaPaths, storeMedia(audioPath, "audio.mp3"))
if content != "" {
content += "\n"
}
content += "[audio]"
}
}
if message.Document != nil {
docPath := c.downloadFile(ctx, message.Document.FileID, "")
if docPath != "" {
mediaPaths = append(mediaPaths, storeMedia(docPath, "document"))
if content != "" {
content += "\n"
}
content += "[file]"
}
}
if content == "" {
content = "[empty message]"
}
// In group chats, apply unified group trigger filtering
if message.Chat.Type != "private" {
isMentioned := c.isBotMentioned(message)
if isMentioned {
content = c.stripBotMention(content)
}
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
if !respond {
return nil
}
content = cleaned
}
// For forum topics, embed the thread ID as "chatID/threadID" so replies
// route to the correct topic and each topic gets its own session.
// Only forum groups (IsForum) are handled; regular group reply threads
// must share one session per group.
compositeChatID := chatIDStr
threadID := message.MessageThreadID
if message.Chat.IsForum && threadID != 0 {
compositeChatID = fmt.Sprintf("%d/%d", chatID, threadID)
}
logger.DebugCF("telegram", "Received message", map[string]any{
"sender_id": sender.CanonicalID,
"chat_id": compositeChatID,
"thread_id": threadID,
"preview": utils.Truncate(content, 50),
})
peerKind := "direct"
peerID := fmt.Sprintf("%d", user.ID)
if message.Chat.Type != "private" {
peerKind = "group"
peerID = compositeChatID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
messageID := fmt.Sprintf("%d", message.MessageID)
metadata := map[string]string{
"user_id": fmt.Sprintf("%d", user.ID),
"username": user.Username,
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
// Set parent_peer metadata for per-topic agent binding.
if message.Chat.IsForum && threadID != 0 {
metadata["parent_peer_kind"] = "topic"
metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID)
}
c.HandleMessage(c.ctx,
peer,
messageID,
platformID,
compositeChatID,
content,
mediaPaths,
metadata,
sender,
)
return nil
}
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{
"error": err.Error(),
})
return ""
}
return c.downloadFileWithInfo(file, ".jpg")
}
func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
if file.FilePath == "" {
return ""
}
url := c.bot.FileDownloadURL(file.FilePath)
// Log only the file path: the download URL embeds the bot token.
logger.DebugCF("telegram", "Downloading file", map[string]any{"file_path": file.FilePath})
// Use FilePath as filename for better identification
filename := file.FilePath + ext
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "telegram",
})
}
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
logger.ErrorCF("telegram", "Failed to get file", map[string]any{
"error": err.Error(),
})
return ""
}
return c.downloadFileWithInfo(file, ext)
}
func parseContent(text string, useMarkdownV2 bool) string {
if useMarkdownV2 {
return markdownToTelegramMarkdownV2(text)
}
return markdownToTelegramHTML(text)
}
// parseTelegramChatID splits "chatID/threadID" into its components.
// Returns threadID=0 when no "/" is present (non-forum messages).
func parseTelegramChatID(chatID string) (int64, int, error) {
idx := strings.Index(chatID, "/")
if idx == -1 {
cid, err := strconv.ParseInt(chatID, 10, 64)
return cid, 0, err
}
cid, err := strconv.ParseInt(chatID[:idx], 10, 64)
if err != nil {
return 0, 0, err
}
tid, err := strconv.Atoi(chatID[idx+1:])
if err != nil {
return 0, 0, fmt.Errorf("invalid thread ID in chat ID %q: %w", chatID, err)
}
return cid, tid, nil
}
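// The "chatID/threadID" encoding that handleMessage produces and
// parseTelegramChatID consumes can be sketched as a round trip. This is a
// minimal, standalone illustration: formatCompositeChatID and
// splitCompositeChatID are hypothetical names, not functions from this package,
// and strings.Cut stands in for the strings.Index logic above.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// formatCompositeChatID (hypothetical) mirrors the encoding in handleMessage:
// a plain chat ID, or "chatID/threadID" for forum topics only.
func formatCompositeChatID(chatID int64, threadID int, isForum bool) string {
	if isForum && threadID != 0 {
		return fmt.Sprintf("%d/%d", chatID, threadID)
	}
	return strconv.FormatInt(chatID, 10)
}

// splitCompositeChatID (hypothetical) is the inverse operation. It returns
// threadID 0 when no "/" is present.
func splitCompositeChatID(s string) (int64, int, error) {
	chatPart, threadPart, hasThread := strings.Cut(s, "/")
	cid, err := strconv.ParseInt(chatPart, 10, 64)
	if err != nil {
		return 0, 0, err
	}
	if !hasThread {
		return cid, 0, nil
	}
	tid, err := strconv.Atoi(threadPart)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid thread ID in chat ID %q: %w", s, err)
	}
	return cid, tid, nil
}

func main() {
	id := formatCompositeChatID(-1001234567890, 42, true)
	cid, tid, err := splitCompositeChatID(id)
	fmt.Println(id, cid, tid, err) // -1001234567890/42 -1001234567890 42 <nil>
}
```

// Because supergroup IDs are negative, the separator must not be "-"; using
// "/" keeps the leading sign of the chat ID unambiguous.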
func logParseFailed(err error, useMarkdownV2 bool) {
parsingName := "HTML"
if useMarkdownV2 {
parsingName = "MarkdownV2"
}
logger.ErrorCF("telegram",
fmt.Sprintf("%s parse failed, falling back to plain text", parsingName),
map[string]any{
"error": err.Error(),
},
)
}
// isBotMentioned checks if the bot is mentioned in the message via entities.
func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
text, entities := telegramEntityTextAndList(message)
if text == "" || len(entities) == 0 {
return false
}
botUsername := ""
if c.bot != nil {
botUsername = c.bot.Username()
}
runes := []rune(text)
for _, entity := range entities {
entityText, ok := telegramEntityText(runes, entity)
if !ok {
continue
}
switch entity.Type {
case telego.EntityTypeMention:
if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) {
return true
}
case telego.EntityTypeTextMention:
if botUsername != "" && entity.User != nil && strings.EqualFold(entity.User.Username, botUsername) {
return true
}
case telego.EntityTypeBotCommand:
if isBotCommandEntityForThisBot(entityText, botUsername) {
return true
}
}
}
return false
}
func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) {
if message.Text != "" {
return message.Text, message.Entities
}
return message.Caption, message.CaptionEntities
}
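// telegramEntityText returns the text covered by entity, treating Offset and
// Length as rune indices. The Telegram Bot API specifies both in UTF-16 code
// units, so text containing characters outside the Basic Multilingual Plane
// (for example most emoji) before the entity may be sliced at the wrong position.
func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) {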
if entity.Offset < 0 || entity.Length <= 0 {
return "", false
}
end := entity.Offset + entity.Length
if entity.Offset >= len(runes) || end > len(runes) {
return "", false
}
return string(runes[entity.Offset:end]), true
}
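// telegramEntityText above indexes by rune, whereas the Telegram Bot API
// documents entity offsets and lengths in UTF-16 code units. The sketch below
// shows one way the same bounds checks could be applied to UTF-16 units.
// entityTextUTF16 is a hypothetical name and an alternative technique, not
// what this package does.

```go
package main

import (
	"fmt"
	"unicode/utf16"
)

// entityTextUTF16 (hypothetical) slices text using offset and length
// expressed in UTF-16 code units, then decodes the slice back to a string.
func entityTextUTF16(text string, offset, length int) (string, bool) {
	if offset < 0 || length <= 0 {
		return "", false
	}
	units := utf16.Encode([]rune(text))
	end := offset + length
	if offset >= len(units) || end > len(units) {
		return "", false
	}
	return string(utf16.Decode(units[offset:end])), true
}

func main() {
	// U+1F600 takes two UTF-16 code units, so "@testbot" starts at
	// offset 3 (two units for the emoji plus one for the space).
	text := "\U0001F600 @testbot hi"
	got, ok := entityTextUTF16(text, 3, 8)
	fmt.Println(got, ok) // @testbot true
}
```

// With rune indexing the same offset and length would select "testbot "
// instead, because the emoji counts as a single rune.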
func isBotCommandEntityForThisBot(entityText, botUsername string) bool {
if !strings.HasPrefix(entityText, "/") {
return false
}
command := strings.TrimPrefix(entityText, "/")
if command == "" {
return false
}
at := strings.IndexRune(command, '@')
if at == -1 {
// A bare /command delivered to this bot is intended for this bot.
return true
}
mentionUsername := command[at+1:]
if mentionUsername == "" || botUsername == "" {
return false
}
return strings.EqualFold(mentionUsername, botUsername)
}
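// stripBotMention removes every case-insensitive "@<botUsername>" occurrence
// from content and trims the surrounding whitespace.
func (c *TelegramChannel) stripBotMention(content string) string {
// Guard against a nil bot, matching the check in isBotMentioned.
if c.bot == nil {
return content
}
botUsername := c.bot.Username()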
if botUsername == "" {
return content
}
// Case-insensitive replacement
re := regexp.MustCompile(`(?i)@` + regexp.QuoteMeta(botUsername))
content = re.ReplaceAllString(content, "")
return strings.TrimSpace(content)
}
================================================
FILE: pkg/channels/telegram/telegram_dispatch_test.go
================================================
package telegram
import (
"context"
"testing"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
)
func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
msg := &telego.Message{
Text: "/new",
MessageID: 9,
Chat: telego.Chat{
ID: 123,
Type: "private",
},
From: &telego.User{
ID: 42,
FirstName: "Alice",
},
}
if err := ch.handleMessage(context.Background(), msg); err != nil {
t.Fatalf("handleMessage error: %v", err)
}
inbound, ok := <-messageBus.InboundChan()
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "telegram" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
================================================
FILE: pkg/channels/telegram/telegram_group_command_filter_test.go
================================================
package telegram
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
type getMeCaller struct {
username string
}
func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) {
if strings.HasSuffix(url, "/getMe") {
result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username)
return &ta.Response{Ok: true, Result: []byte(result)}, nil
}
return &ta.Response{Ok: true, Result: []byte("true")}, nil
}
func newTestTelegramBot(t *testing.T, username string) *telego.Bot {
t.Helper()
token := "123456:" + strings.Repeat("a", 35)
bot, err := telego.NewBot(token,
telego.WithAPICaller(getMeCaller{username: username}),
telego.WithDiscardLogger(),
)
if err != nil {
t.Fatalf("NewBot error: %v", err)
}
return bot
}
func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) {
t.Helper()
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil,
channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}),
),
bot: newTestTelegramBot(t, botUsername),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
return ch, messageBus
}
func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
tests := []struct {
name string
text string
wantForwarded bool
wantContent string
}{
{
name: "command with bot username",
text: "/new@testbot",
wantForwarded: true,
wantContent: "/new",
},
{
name: "bare command",
text: "/new",
wantForwarded: true,
wantContent: "/new",
},
{
name: "command for another bot",
text: "/new@otherbot",
wantForwarded: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ch, messageBus := newGroupMentionOnlyChannel(t, "testbot")
msg := &telego.Message{
Text: tc.text,
Entities: []telego.MessageEntity{{
Type: telego.EntityTypeBotCommand,
Offset: 0,
Length: len([]rune(tc.text)),
}},
MessageID: 42,
Chat: telego.Chat{
ID: 123,
Type: "group",
},
From: &telego.User{
ID: 7,
FirstName: "Alice",
},
}
if err := ch.handleMessage(context.Background(), msg); err != nil {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
select {
case <-ctx.Done():
if tc.wantForwarded {
t.Fatal("timeout waiting for message to be forwarded")
}
case inbound, ok := <-messageBus.InboundChan():
if !tc.wantForwarded {
// A command addressed to another bot must not reach the bus.
t.Fatalf("unexpected forwarded message: content=%q", inbound.Content)
}
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
}
}
})
}
}
func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) {
ch, _ := newGroupMentionOnlyChannel(t, "testbot")
msg := &telego.Message{
Text: "@testbot hello",
Entities: []telego.MessageEntity{{
Type: telego.EntityTypeMention,
Offset: 0,
Length: len("@testbot"),
}},
}
if !ch.isBotMentioned(msg) {
t.Fatal("expected mention entity to be treated as bot mention")
}
}
================================================
FILE: pkg/channels/telegram/telegram_test.go
================================================
package telegram
import (
"context"
"encoding/json"
"errors"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
const testToken = "1234567890:aaaabbbbaaaabbbbaaaabbbbaaaabbbbccc"
// stubCaller implements ta.Caller for testing.
type stubCaller struct {
calls []stubCall
callFn func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error)
}
type stubCall struct {
URL string
Data *ta.RequestData
}
func (s *stubCaller) Call(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
s.calls = append(s.calls, stubCall{URL: url, Data: data})
return s.callFn(ctx, url, data)
}
// stubConstructor implements ta.RequestConstructor for testing.
type stubConstructor struct{}
type multipartCall struct {
Parameters map[string]string
FileSizes map[string]int
}
func (s *stubConstructor) JSONRequest(parameters any) (*ta.RequestData, error) {
b, err := json.Marshal(parameters)
if err != nil {
return nil, err
}
return &ta.RequestData{
ContentType: "application/json",
BodyRaw: b,
}, nil
}
func (s *stubConstructor) MultipartRequest(
parameters map[string]string,
files map[string]ta.NamedReader,
) (*ta.RequestData, error) {
return &ta.RequestData{}, nil
}
type multipartRecordingConstructor struct {
stubConstructor
calls []multipartCall
}
func (s *multipartRecordingConstructor) MultipartRequest(
parameters map[string]string,
files map[string]ta.NamedReader,
) (*ta.RequestData, error) {
call := multipartCall{
Parameters: make(map[string]string, len(parameters)),
FileSizes: make(map[string]int, len(files)),
}
for k, v := range parameters {
call.Parameters[k] = v
}
for field, file := range files {
if file == nil {
continue
}
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}
call.FileSizes[field] = len(data)
}
s.calls = append(s.calls, call)
return &ta.RequestData{}, nil
}
// successResponse returns a ta.Response that telego will treat as a successful SendMessage.
func successResponse(t *testing.T) *ta.Response {
t.Helper()
msg := &telego.Message{MessageID: 1}
b, err := json.Marshal(msg)
require.NoError(t, err)
return &ta.Response{Ok: true, Result: b}
}
// newTestChannel creates a TelegramChannel with a mocked bot for unit testing.
func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel {
return newTestChannelWithConstructor(t, caller, &stubConstructor{})
}
func newTestChannelWithConstructor(
t *testing.T,
caller *stubCaller,
constructor ta.RequestConstructor,
) *TelegramChannel {
t.Helper()
bot, err := telego.NewBot(testToken,
telego.WithAPICaller(caller),
telego.WithRequestConstructor(constructor),
telego.WithDiscardLogger(),
)
require.NoError(t, err)
base := channels.NewBaseChannel("telegram", nil, nil, nil,
channels.WithMaxMessageLength(4000),
)
base.SetRunning(true)
return &TelegramChannel{
BaseChannel: base,
bot: bot,
chatIDs: make(map[string]int64),
config: config.DefaultConfig(),
}
}
func TestSendMedia_ImageFallbacksToDocumentOnInvalidDimensions(t *testing.T) {
constructor := &multipartRecordingConstructor{}
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
switch {
case strings.Contains(url, "sendPhoto"):
return nil, errors.New(`api: 400 "Bad Request: PHOTO_INVALID_DIMENSIONS"`)
case strings.Contains(url, "sendDocument"):
return successResponse(t), nil
default:
t.Fatalf("unexpected API call: %s", url)
return nil, nil
}
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
localPath := filepath.Join(tmpDir, "woodstock-en-10s.png")
content := []byte("fake-png-content")
require.NoError(t, os.WriteFile(localPath, content, 0o644))
ref, err := store.Store(
localPath,
media.MediaMeta{Filename: "woodstock-en-10s.png", ContentType: "image/png"},
"scope-1",
)
require.NoError(t, err)
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: []bus.MediaPart{{
Type: "image",
Ref: ref,
Caption: "caption",
}},
})
require.NoError(t, err)
require.Len(t, caller.calls, 2)
assert.Contains(t, caller.calls[0].URL, "sendPhoto")
assert.Contains(t, caller.calls[1].URL, "sendDocument")
require.Len(t, constructor.calls, 2)
assert.Equal(t, len(content), constructor.calls[0].FileSizes["photo"])
assert.Equal(t, len(content), constructor.calls[1].FileSizes["document"])
assert.Equal(t, "caption", constructor.calls[1].Parameters["caption"])
}
func TestSendMedia_ImageNonDimensionErrorDoesNotFallback(t *testing.T) {
constructor := &multipartRecordingConstructor{}
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return nil, errors.New("api: 500 \"server exploded\"")
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
localPath := filepath.Join(tmpDir, "image.png")
require.NoError(t, os.WriteFile(localPath, []byte("fake-png-content"), 0o644))
ref, err := store.Store(localPath, media.MediaMeta{Filename: "image.png", ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: []bus.MediaPart{{
Type: "image",
Ref: ref,
}},
})
require.Error(t, err)
assert.ErrorIs(t, err, channels.ErrTemporary)
require.Len(t, caller.calls, 1)
assert.Contains(t, caller.calls[0].URL, "sendPhoto")
require.Len(t, constructor.calls, 1)
assert.NotContains(t, caller.calls[0].URL, "sendDocument")
}
func TestSend_EmptyContent(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
t.Fatal("SendMessage should not be called for empty content")
return nil, nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "",
})
assert.NoError(t, err)
assert.Empty(t, caller.calls, "no API calls should be made for empty content")
}
func TestSend_ShortMessage_SingleCall(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello, world!",
})
assert.NoError(t, err)
assert.Len(t, caller.calls, 1, "short message should result in exactly one SendMessage call")
}
func TestSend_LongMessage_SingleCall(t *testing.T) {
// With WithMaxMessageLength(4000), the Manager pre-splits messages before
// they reach Send(). A message at exactly 4000 chars should go through
// as a single SendMessage call (no re-split needed since HTML expansion
// won't exceed 4096 for plain text).
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
longContent := strings.Repeat("a", 4000)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: longContent,
})
assert.NoError(t, err)
assert.Len(t, caller.calls, 1, "pre-split message within limit should result in one SendMessage call")
}
func TestSend_HTMLFallback_PerChunk(t *testing.T) {
callCount := 0
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
callCount++
// Fail on odd calls (HTML attempt), succeed on even calls (plain text fallback)
if callCount%2 == 1 {
return nil, errors.New("Bad Request: can't parse entities")
}
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello **world**",
})
assert.NoError(t, err)
// One short message → 1 HTML attempt (fail) + 1 plain text fallback (success) = 2 calls
assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text fallback")
}
func TestSend_HTMLFallback_BothFail(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return nil, errors.New("send failed")
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello",
})
assert.Error(t, err)
assert.True(t, errors.Is(err, channels.ErrTemporary), "error should wrap ErrTemporary")
assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text attempt")
}
func TestSend_LongMessage_HTMLFallback_StopsOnError(t *testing.T) {
// With a long message that gets split into 2 chunks, if both HTML and
// plain text fail on the first chunk, Send should return early.
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return nil, errors.New("send failed")
},
}
ch := newTestChannel(t, caller)
longContent := strings.Repeat("x", 4001)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: longContent,
})
assert.Error(t, err)
// Should fail on the first chunk (2 calls: HTML + fallback), never reaching the second chunk.
assert.Equal(t, 2, len(caller.calls), "should stop after first chunk fails both HTML and plain text")
}
func TestSend_MarkdownShortButHTMLLong_MultipleCalls(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
// Create markdown whose length is <= 4000 but whose HTML expansion is much longer.
// "**a** " (6 chars) becomes "<b>a</b> " (9 chars) in HTML, so repeating it many times
// yields HTML that exceeds Telegram's limit while markdown stays within it.
markdownContent := strings.Repeat("**a** ", 600) // 3600 chars markdown, HTML ~5400+ chars
assert.LessOrEqual(t, len([]rune(markdownContent)), 4000, "markdown content must not exceed chunk size")
htmlExpanded := markdownToTelegramHTML(markdownContent)
assert.Greater(
t, len([]rune(htmlExpanded)), 4096,
"HTML expansion must exceed Telegram limit for this test to be meaningful",
)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: markdownContent,
})
assert.NoError(t, err)
assert.Greater(
t, len(caller.calls), 1,
"markdown-short but HTML-long message should be split into multiple SendMessage calls",
)
}
func TestSend_HTMLOverflow_WordBoundary(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
// We want to force a split near index ~2600 while keeping markdown length <= 4000.
// Prefix of 430 bold units (6 chars each) = 2580 chars.
// Expansion per unit is +3 chars when converted to HTML, so 2580 + 430*3 = 3870.
prefix := strings.Repeat("**a** ", 430)
targetWord := "TARGETWORDTHATSTAYSTOGETHER"
// Suffix of 230 bold units (6 chars each) = 1380 chars.
// Total markdown length: 2580 (prefix) + 27 (target word) + 1380 (suffix) = 3987 <= 4000.
// HTML expansion adds ~3 chars per bold unit: (430 + 230)*3 = 1980 extra chars,
// so total HTML length comfortably exceeds 4096.
suffix := strings.Repeat(" **b**", 230)
content := prefix + targetWord + suffix
// Ensure the test content matches the intended boundary conditions.
assert.LessOrEqual(t, len([]rune(content)), 4000, "markdown content must not exceed chunk size for this test")
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "123456",
Content: content,
})
assert.NoError(t, err)
foundFullWord := false
for i, call := range caller.calls {
var params map[string]any
err := json.Unmarshal(call.Data.BodyRaw, &params)
require.NoError(t, err)
text, _ := params["text"].(string)
hasWord := strings.Contains(text, targetWord)
t.Logf("Chunk %d length: %d, contains target word: %v", i, len(text), hasWord)
if hasWord {
foundFullWord = true
break
}
}
assert.True(t, foundFullWord, "The target word should not be split between chunks")
}
func TestSend_NotRunning(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
t.Fatal("should not be called")
return nil, nil
},
}
ch := newTestChannel(t, caller)
ch.SetRunning(false)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello",
})
assert.ErrorIs(t, err, channels.ErrNotRunning)
assert.Empty(t, caller.calls)
}
func TestSend_InvalidChatID(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
t.Fatal("should not be called")
return nil, nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "not-a-number",
Content: "Hello",
})
assert.Error(t, err)
assert.True(t, errors.Is(err, channels.ErrSendFailed), "error should wrap ErrSendFailed")
assert.Empty(t, caller.calls)
}
func TestParseTelegramChatID_Plain(t *testing.T) {
cid, tid, err := parseTelegramChatID("12345")
assert.NoError(t, err)
assert.Equal(t, int64(12345), cid)
assert.Equal(t, 0, tid)
}
func TestParseTelegramChatID_NegativeGroup(t *testing.T) {
cid, tid, err := parseTelegramChatID("-1001234567890")
assert.NoError(t, err)
assert.Equal(t, int64(-1001234567890), cid)
assert.Equal(t, 0, tid)
}
func TestParseTelegramChatID_WithThreadID(t *testing.T) {
cid, tid, err := parseTelegramChatID("-1001234567890/42")
assert.NoError(t, err)
assert.Equal(t, int64(-1001234567890), cid)
assert.Equal(t, 42, tid)
}
func TestParseTelegramChatID_GeneralTopic(t *testing.T) {
cid, tid, err := parseTelegramChatID("-100123/1")
assert.NoError(t, err)
assert.Equal(t, int64(-100123), cid)
assert.Equal(t, 1, tid)
}
func TestParseTelegramChatID_Invalid(t *testing.T) {
_, _, err := parseTelegramChatID("not-a-number")
assert.Error(t, err)
}
func TestParseTelegramChatID_InvalidThreadID(t *testing.T) {
_, _, err := parseTelegramChatID("-100123/not-a-thread")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid thread ID")
}
func TestSend_WithForumThreadID(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "-1001234567890/42",
Content: "Hello from topic",
})
assert.NoError(t, err)
assert.Len(t, caller.calls, 1)
}
func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
msg := &telego.Message{
Text: "hello from topic",
MessageID: 10,
MessageThreadID: 42,
Chat: telego.Chat{
ID: -1001234567890,
Type: "supergroup",
IsForum: true,
},
From: &telego.User{
ID: 7,
FirstName: "Alice",
},
}
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID
assert.Equal(t, "-1001234567890/42", inbound.ChatID)
// Peer ID should include thread ID for session key isolation
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-1001234567890/42", inbound.Peer.ID)
// Parent peer metadata should be set for agent binding
assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"])
assert.Equal(t, "42", inbound.Metadata["parent_peer_id"])
}
func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
msg := &telego.Message{
Text: "regular group message",
MessageID: 11,
Chat: telego.Chat{
ID: -100999,
Type: "group",
},
From: &telego.User{
ID: 8,
FirstName: "Bob",
},
}
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// Plain chatID without thread suffix
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (no thread suffix)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
}
func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
// In regular groups, reply threads set MessageThreadID to the original
// message ID. This should NOT trigger per-thread session isolation.
msg := &telego.Message{
Text: "reply in thread",
MessageID: 20,
MessageThreadID: 15,
Chat: telego.Chat{
ID: -100999,
Type: "supergroup",
IsForum: false,
},
From: &telego.User{
ID: 9,
FirstName: "Carol",
},
}
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// chatID should NOT include thread suffix for non-forum groups
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (shared session for whole group)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
}
================================================
FILE: pkg/channels/telegram/testdata/md2_all_formats.txt
================================================
*bold \*text*
_italic \*text_
__underline__
~strikethrough~
||spoiler||
*bold _italic bold ~italic bold strikethrough ||italic bold strikethrough spoiler||~ __underline italic bold___ bold*
[inline URL](http://www.example.com/)
[inline mention of a user](tg://user?id=123456789)





`inline fixed-width code`
```
pre-formatted fixed-width code block
```
```python
pre-formatted fixed-width code block written in the Python programming language
```
>Block quotation started
>Block quotation continued
>Block quotation continued
>Block quotation continued
>The last line of the block quotation
**>The expandable block quotation started right after the previous block quotation
>It is separated from the previous block quotation by an empty bold entity
>Expandable block quotation continued
>Hidden by default part of the expandable block quotation started
>Expandable block quotation continued
>The last line of the expandable block quotation with the expandability mark||
================================================
FILE: pkg/channels/webhook.go
================================================
package channels
import "net/http"
// WebhookHandler is an optional interface for channels that receive messages
// via HTTP webhooks. Manager discovers channels implementing this interface
// and registers them on the shared HTTP server.
type WebhookHandler interface {
// WebhookPath returns the path to mount this handler on the shared server.
// Examples: "/webhook/line", "/webhook/wecom"
WebhookPath() string
http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request)
}
// HealthChecker is an optional interface for channels that expose
// a health check endpoint on the shared HTTP server.
type HealthChecker interface {
HealthPath() string
HealthHandler(w http.ResponseWriter, r *http.Request)
}
================================================
FILE: pkg/channels/wecom/aibot.go
================================================
package wecom
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
// responseURLHTTPClient is a shared HTTP client for posting to WeCom response_url.
// Reusing it enables connection pooling across replies.
var responseURLHTTPClient = &http.Client{Timeout: 15 * time.Second}
// WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人)
type WeComAIBotChannel struct {
*channels.BaseChannel
config config.WeComAIBotConfig
ctx context.Context
cancel context.CancelFunc
streamTasks map[string]*streamTask // streamID -> task (for poll lookups)
chatTasks map[string][]*streamTask // chatID -> in-flight tasks queue (FIFO)
taskMu sync.RWMutex
}
// streamTask represents a streaming task for AI Bot.
//
// Mutable fields (Finished, StreamClosed, StreamClosedAt) must be read/written
// while holding WeComAIBotChannel.taskMu. Immutable fields (StreamID, ChatID,
// ResponseURL, Question, CreatedTime, Deadline, answerCh, ctx, cancel) are set
// once at creation and never modified, so they are safe to read without a lock.
type streamTask struct {
// immutable after creation
StreamID string
ChatID string // used by Send() to find this task
ResponseURL string // temporary URL for proactive reply (valid 1 hour, use once)
Question string
CreatedTime time.Time
Deadline time.Time // ~30s, we close the stream here and switch to response_url
answerCh chan string // receives agent reply from Send()
ctx context.Context // canceled when task is removed; used to interrupt the agent goroutine
cancel context.CancelFunc // call on task removal to cancel ctx
// mutable — guarded by WeComAIBotChannel.taskMu
StreamClosed bool // stream returned finish:true; waiting for agent to reply via response_url
StreamClosedAt time.Time // set when StreamClosed becomes true; used for accelerated cleanup
Finished bool // fully done
}
// WeComAIBotMessage represents the decrypted JSON message from WeCom AI Bot
// Ref: https://developer.work.weixin.qq.com/document/path/100719
type WeComAIBotMessage struct {
MsgID string `json:"msgid"`
AIBotID string `json:"aibotid"`
ChatID string `json:"chatid"` // only for group chat
ChatType string `json:"chattype"` // "single" or "group"
From struct {
UserID string `json:"userid"`
} `json:"from"`
ResponseURL string `json:"response_url"` // temporary URL for proactive reply
MsgType string `json:"msgtype"`
// text message
Text *struct {
Content string `json:"content"`
} `json:"text,omitempty"`
// stream polling refresh
Stream *struct {
ID string `json:"id"`
} `json:"stream,omitempty"`
// image message
Image *struct {
URL string `json:"url"`
} `json:"image,omitempty"`
// mixed message (text + image)
Mixed *struct {
MsgItem []struct {
MsgType string `json:"msgtype"`
Text *struct {
Content string `json:"content"`
} `json:"text,omitempty"`
Image *struct {
URL string `json:"url"`
} `json:"image,omitempty"`
} `json:"msg_item"`
} `json:"mixed,omitempty"`
// event field
Event *struct {
EventType string `json:"eventtype"`
} `json:"event,omitempty"`
}
// WeComAIBotMsgItemImage holds the image payload inside a stream message item.
type WeComAIBotMsgItemImage struct {
Base64 string `json:"base64"`
MD5 string `json:"md5"`
}
// WeComAIBotMsgItem is a single item inside a stream's msg_item list.
type WeComAIBotMsgItem struct {
MsgType string `json:"msgtype"`
Image *WeComAIBotMsgItemImage `json:"image,omitempty"`
}
// WeComAIBotStreamInfo represents the detailed stream content in streaming responses.
type WeComAIBotStreamInfo struct {
ID string `json:"id"`
Finish bool `json:"finish"`
Content string `json:"content,omitempty"`
MsgItem []WeComAIBotMsgItem `json:"msg_item,omitempty"`
}
// WeComAIBotStreamResponse represents the streaming response format
type WeComAIBotStreamResponse struct {
MsgType string `json:"msgtype"`
Stream WeComAIBotStreamInfo `json:"stream"`
}
// WeComAIBotEncryptedResponse represents the encrypted response wrapper
// Fields match WXBizJsonMsgCrypt.generate() in Python SDK
type WeComAIBotEncryptedResponse struct {
Encrypt string `json:"encrypt"`
MsgSignature string `json:"msgsignature"`
Timestamp string `json:"timestamp"`
Nonce string `json:"nonce"`
}
// NewWeComAIBotChannel creates a WeCom AI Bot channel instance.
// If cfg.BotID and cfg.Secret are both set, it returns a WeComAIBotWSChannel
// using the WebSocket long-connection API.
// Otherwise it returns the webhook-mode WeComAIBotChannel (requires Token +
// EncodingAESKey).
func NewWeComAIBotChannel(
cfg config.WeComAIBotConfig,
messageBus *bus.MessageBus,
) (channels.Channel, error) {
// WebSocket long-connection mode takes priority when BotID + Secret are set.
if cfg.BotID != "" && cfg.Secret != "" {
logger.InfoC("wecom_aibot", "BotID and Secret provided, using WebSocket mode")
return newWeComAIBotWSChannel(cfg, messageBus)
}
// Webhook (short-connection) mode.
if cfg.Token == "" || cfg.EncodingAESKey == "" {
return nil, fmt.Errorf(
"WeCom AI Bot requires either (bot_id + secret) for WebSocket mode " +
"or (token + encoding_aes_key) for webhook mode")
}
if cfg.ProcessingMessage == "" {
cfg.ProcessingMessage = config.DefaultWeComAIBotProcessingMessage
}
base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom,
channels.WithMaxMessageLength(2048),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
return &WeComAIBotChannel{
BaseChannel: base,
config: cfg,
streamTasks: make(map[string]*streamTask),
chatTasks: make(map[string][]*streamTask),
}, nil
}
// Name returns the channel name
func (c *WeComAIBotChannel) Name() string {
return "wecom_aibot"
}
// Start initializes the WeCom AI Bot channel
func (c *WeComAIBotChannel) Start(ctx context.Context) error {
logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel...")
c.ctx, c.cancel = context.WithCancel(ctx)
// Start cleanup goroutine for old tasks
go c.cleanupLoop()
c.SetRunning(true)
logger.InfoC("wecom_aibot", "WeCom AI Bot channel started")
return nil
}
// Stop gracefully stops the WeCom AI Bot channel
func (c *WeComAIBotChannel) Stop(ctx context.Context) error {
logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel...")
if c.cancel != nil {
c.cancel()
}
c.SetRunning(false)
logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped")
return nil
}
// Send delivers the agent reply into the active streamTask for msg.ChatID.
// It writes into the earliest unfinished task in the queue (FIFO per chatID).
// If the stream has already closed (deadline passed), it posts directly to response_url.
func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
c.taskMu.Lock()
queue := c.chatTasks[msg.ChatID]
// Only compact Finished tasks at the head of the queue.
// Tasks that are Finished in the middle are NOT removed here: doing a full
// scan on every Send() call would be O(n) and is unnecessary given that
// removeTask() always splices the task out of the queue immediately.
// Any Finished task left stranded in the middle (e.g. due to an unexpected
// code path) will be collected by cleanupOldTasks.
for len(queue) > 0 && queue[0].Finished {
queue = queue[1:]
}
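// Drop the map entry when nothing is left so that Send() calls for chats
// without in-flight tasks do not leave empty queues behind.
if len(queue) == 0 {
delete(c.chatTasks, msg.ChatID)
} else {
c.chatTasks[msg.ChatID] = queue
}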
var task *streamTask
var streamClosed bool
var responseURL string
if len(queue) > 0 {
task = queue[0]
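// Snapshot the fields we need while holding c.taskMu: StreamClosed is mutable
// and must not be read after unlocking (ResponseURL is immutable, copied for convenience).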
streamClosed = task.StreamClosed
responseURL = task.ResponseURL
}
c.taskMu.Unlock()
if task == nil {
logger.DebugCF(
"wecom_aibot",
"Send: no active task for chat (may have timed out)",
map[string]any{
"chat_id": msg.ChatID,
},
)
return nil
}
if streamClosed {
// Stream already ended with a "please wait" notice; send the real reply via response_url.
// Note: task.StreamID and task.ChatID are immutable, safe to read without a lock.
logger.InfoCF("wecom_aibot", "Sending reply via response_url", map[string]any{
"stream_id": task.StreamID,
"chat_id": msg.ChatID,
})
if responseURL != "" {
if err := c.sendViaResponseURL(responseURL, msg.Content); err != nil {
logger.ErrorCF("wecom_aibot", "Failed to send via response_url", map[string]any{
"error": err,
"stream_id": task.StreamID,
})
c.removeTask(task)
return fmt.Errorf("response_url delivery failed: %w", channels.ErrSendFailed)
}
} else {
logger.WarnCF("wecom_aibot", "Stream closed but no response_url available", map[string]any{
"stream_id": task.StreamID,
})
}
c.removeTask(task)
return nil
}
// Stream still open: deliver via answerCh for the next poll response.
select {
case task.answerCh <- msg.Content:
case <-task.ctx.Done():
// Task was canceled (cleanup removed it); silently drop the reply.
return nil
case <-ctx.Done():
return ctx.Err()
}
return nil
}
// WebhookPath returns the path for registering on the shared HTTP server
func (c *WeComAIBotChannel) WebhookPath() string {
if c.config.WebhookPath == "" {
return "/webhook/wecom-aibot"
}
return c.config.WebhookPath
}
// ServeHTTP implements http.Handler for the shared HTTP server
func (c *WeComAIBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.handleWebhook(w, r)
}
// HealthPath returns the health check endpoint path
func (c *WeComAIBotChannel) HealthPath() string {
return c.WebhookPath() + "/health"
}
// HealthHandler handles health check requests
func (c *WeComAIBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
c.handleHealth(w, r)
}
// handleWebhook handles incoming webhook requests from WeCom AI Bot
func (c *WeComAIBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Log all incoming requests for debugging
logger.DebugCF("wecom_aibot", "Received webhook request", map[string]any{
"method": r.Method,
"path": r.URL.Path,
"query": r.URL.RawQuery,
})
switch r.Method {
case http.MethodGet:
// URL verification
c.handleVerification(ctx, w, r)
case http.MethodPost:
// Message callback
c.handleMessageCallback(ctx, w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
// handleVerification handles the URL verification request from WeCom
func (c *WeComAIBotChannel) handleVerification(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
) {
msgSignature := r.URL.Query().Get("msg_signature")
timestamp := r.URL.Query().Get("timestamp")
nonce := r.URL.Query().Get("nonce")
echostr := r.URL.Query().Get("echostr")
logger.DebugCF("wecom_aibot", "URL verification request", map[string]any{
"msg_signature": msgSignature,
"timestamp": timestamp,
"nonce": nonce,
})
// Verify signature
if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
logger.ErrorC("wecom_aibot", "Signature verification failed")
http.Error(w, "Signature verification failed", http.StatusUnauthorized)
return
}
// Decrypt echostr
// For WeCom AI Bot (智能机器人), receiveid must be an empty string
decrypted, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "")
if err != nil {
logger.ErrorCF("wecom_aibot", "Failed to decrypt echostr", map[string]any{
"error": err,
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
// Remove BOM and whitespace as per WeCom documentation
decrypted = strings.TrimPrefix(decrypted, "\ufeff")
decrypted = strings.TrimSpace(decrypted)
logger.InfoC("wecom_aibot", "URL verification successful")
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write([]byte(decrypted))
}
// handleMessageCallback handles incoming messages from WeCom AI Bot
func (c *WeComAIBotChannel) handleMessageCallback(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
) {
msgSignature := r.URL.Query().Get("msg_signature")
timestamp := r.URL.Query().Get("timestamp")
nonce := r.URL.Query().Get("nonce")
// Read request body (limit to 4 MB to prevent memory exhaustion)
const maxBodySize = 4 << 20 // 4 MB
body, err := io.ReadAll(io.LimitReader(r.Body, maxBodySize+1))
if err != nil {
logger.ErrorCF("wecom_aibot", "Failed to read request body", map[string]any{
"error": err,
})
http.Error(w, "Failed to read body", http.StatusBadRequest)
return
}
if len(body) > maxBodySize {
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
return
}
// Parse JSON body to get encrypted message
// Format: {"encrypt": "base64_encrypted_string"}
var encryptedMsg struct {
Encrypt string `json:"encrypt"`
}
if unmarshalErr := json.Unmarshal(body, &encryptedMsg); unmarshalErr != nil {
logger.ErrorCF("wecom_aibot", "Failed to parse JSON body", map[string]any{
"error": unmarshalErr,
"body": string(body),
})
http.Error(w, "Failed to parse JSON", http.StatusBadRequest)
return
}
// Verify signature
if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
logger.ErrorC("wecom_aibot", "Signature verification failed")
http.Error(w, "Signature verification failed", http.StatusUnauthorized)
return
}
// Decrypt message
// For WeCom AI Bot (智能机器人), receiveid is an empty string
decrypted, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "")
if err != nil {
logger.ErrorCF("wecom_aibot", "Failed to decrypt message", map[string]any{
"error": err,
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
// Parse decrypted JSON message
var msg WeComAIBotMessage
if unmarshalErr := json.Unmarshal([]byte(decrypted), &msg); unmarshalErr != nil {
logger.ErrorCF("wecom_aibot", "Failed to parse decrypted JSON", map[string]any{
"error": unmarshalErr,
"decrypted": decrypted,
})
http.Error(w, "Failed to parse message", http.StatusInternalServerError)
return
}
logger.DebugCF("wecom_aibot", "Decrypted message", map[string]any{
"msgtype": msg.MsgType,
})
// Process the message and get streaming response
response := c.processMessage(ctx, msg, timestamp, nonce)
// Check if response is empty (e.g. due to unsupported message type)
if response == "" {
response = c.encryptEmptyResponse(timestamp, nonce)
}
// Return encrypted JSON response
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write([]byte(response))
}
// processMessage processes the received message and returns encrypted response
func (c *WeComAIBotChannel) processMessage(
ctx context.Context,
msg WeComAIBotMessage,
timestamp, nonce string,
) string {
logger.DebugCF("wecom_aibot", "Processing message", map[string]any{
"msgtype": msg.MsgType,
})
switch msg.MsgType {
case "text":
return c.handleTextMessage(ctx, msg, timestamp, nonce)
case "stream":
return c.handleStreamMessage(ctx, msg, timestamp, nonce)
case "image":
return c.handleImageMessage(ctx, msg, timestamp, nonce)
case "mixed":
return c.handleMixedMessage(ctx, msg, timestamp, nonce)
case "event":
return c.handleEventMessage(ctx, msg, timestamp, nonce)
default:
logger.WarnCF("wecom_aibot", "Unsupported message type", map[string]any{
"msgtype": msg.MsgType,
})
return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
MsgType: "stream",
Stream: WeComAIBotStreamInfo{
ID: c.generateStreamID(),
Finish: true,
Content: "Unsupported message type: " + msg.MsgType,
},
})
}
}
// handleTextMessage handles text messages by starting a new streaming task
func (c *WeComAIBotChannel) handleTextMessage(
ctx context.Context,
msg WeComAIBotMessage,
timestamp, nonce string,
) string {
if msg.Text == nil {
logger.ErrorC("wecom_aibot", "text message missing text field")
return c.encryptEmptyResponse(timestamp, nonce)
}
content := msg.Text.Content
userID := msg.From.UserID
if userID == "" {
userID = "unknown"
}
// chatID: group chat uses chatid, single chat uses userid
chatID := msg.ChatID
if chatID == "" {
chatID = userID
}
streamID := c.generateStreamID()
// WeCom stops sending stream-refresh callbacks after 6 minutes.
// We use a much shorter 30-second deadline: once it passes, the stream is closed
// with a processing notice and the real reply is delivered via response_url.
deadline := time.Now().Add(30 * time.Second)
// Each task gets its own context derived from the channel lifetime context.
// Canceling taskCancel interrupts the agent goroutine when the task is removed.
taskCtx, taskCancel := context.WithCancel(c.ctx)
task := &streamTask{
StreamID: streamID,
ChatID: chatID,
ResponseURL: msg.ResponseURL,
Question: content,
CreatedTime: time.Now(),
Deadline: deadline,
Finished: false,
answerCh: make(chan string, 1),
ctx: taskCtx,
cancel: taskCancel,
}
c.taskMu.Lock()
c.streamTasks[streamID] = task
c.chatTasks[chatID] = append(c.chatTasks[chatID], task)
c.taskMu.Unlock()
// Publish to agent asynchronously; agent will call Send() with reply.
// Use task.ctx (not c.ctx) so the agent goroutine is canceled when the task is removed.
go func() {
sender := bus.SenderInfo{
Platform: "wecom_aibot",
PlatformID: userID,
CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID),
DisplayName: userID,
}
peerKind := "direct"
if msg.ChatType == "group" {
peerKind = "group"
}
peer := bus.Peer{Kind: peerKind, ID: chatID}
metadata := map[string]string{
"channel": "wecom_aibot",
"chat_type": msg.ChatType,
"msg_type": "text",
"msgid": msg.MsgID,
"aibotid": msg.AIBotID,
"stream_id": streamID,
"response_url": msg.ResponseURL,
}
c.HandleMessage(task.ctx, peer, msg.MsgID, userID, chatID,
content, nil, metadata, sender)
}()
// Return first streaming response immediately (finish=false, content empty)
return c.getStreamResponse(task, timestamp, nonce)
}
// handleStreamMessage handles stream polling requests
func (c *WeComAIBotChannel) handleStreamMessage(
ctx context.Context,
msg WeComAIBotMessage,
timestamp, nonce string,
) string {
if msg.Stream == nil {
logger.ErrorC("wecom_aibot", "Stream message missing stream field")
return c.encryptEmptyResponse(timestamp, nonce)
}
streamID := msg.Stream.ID
c.taskMu.RLock()
task, exists := c.streamTasks[streamID]
c.taskMu.RUnlock()
if !exists {
logger.DebugCF(
"wecom_aibot",
"Stream task not found (may be from previous session)",
map[string]any{
"stream_id": streamID,
},
)
return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{
MsgType: "stream",
Stream: WeComAIBotStreamInfo{
ID: streamID,
Finish: true,
Content: "Task not found or already finished. Please resend your message to start a new session.",
},
})
}
// Get next response
return c.getStreamResponse(task, timestamp, nonce)
}
// handleImageMessage handles image messages
func (c *WeComAIBotChannel) handleImageMessage(
ctx context.Context,
msg WeComAIBotMessage,
timestamp, nonce string,
) string {
logger.WarnC("wecom_aibot", "Image message type not yet fully implemented")
if msg.Image == nil {
logger.ErrorC("wecom_aibot", "Image message missing image field")
return c.encryptEmptyResponse(timestamp, nonce)
}
imageURL := msg.Image.URL
// For now, just acknowledge receipt without echoing the image
return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
MsgType: "stream",
Stream: WeComAIBotStreamInfo{
ID: c.generateStreamID(),
Finish: true,
Content: fmt.Sprintf(
"Image received (URL: %s), but image messages are not yet supported",
imageURL,
),
},
})
}
// handleMixedMessage handles mixed (text + image) messages
func (c *WeComAIBotChannel) handleMixedMessage(
ctx context.Context,
msg WeComAIBotMessage,
timestamp, nonce string,
) string {
logger.WarnC("wecom_aibot", "Mixed message type not yet fully implemented")
return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
MsgType: "stream",
Stream: WeComAIBotStreamInfo{
ID: c.generateStreamID(),
Finish: true,
Content: "Mixed message type is not yet supported",
},
})
}
// handleEventMessage handles event messages
func (c *WeComAIBotChannel) handleEventMessage(
ctx context.Context,
msg WeComAIBotMessage,
timestamp, nonce string,
) string {
eventType := ""
if msg.Event != nil {
eventType = msg.Event.EventType
}
logger.DebugCF("wecom_aibot", "Received event", map[string]any{
"event_type": eventType,
})
// Send welcome message when user opens the chat window
if eventType == "enter_chat" && c.config.WelcomeMessage != "" {
streamID := c.generateStreamID()
return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{
MsgType: "stream",
Stream: WeComAIBotStreamInfo{
ID: streamID,
Finish: true,
Content: c.config.WelcomeMessage,
},
})
}
return c.encryptEmptyResponse(timestamp, nonce)
}
// getStreamResponse gets the next streaming response for a task.
// - If agent replied: return finish=true with the real answer.
// - If deadline passed: return finish=true with a "please wait" notice, keep task alive for response_url.
// - Otherwise: return finish=false (empty), client will poll again.
func (c *WeComAIBotChannel) getStreamResponse(task *streamTask, timestamp, nonce string) string {
var content string
var finish bool
var closeStreamOnly bool // close stream but do NOT remove task (response_url still pending)
select {
case answer := <-task.answerCh:
// Agent replied before deadline — normal finish.
content = answer
finish = true
default:
if time.Now().After(task.Deadline) {
// Deadline reached: close the stream with a notice, then wait for agent via response_url.
content = c.config.ProcessingMessage
finish = true
closeStreamOnly = true
logger.InfoCF(
"wecom_aibot",
"Stream deadline reached, switching to response_url mode",
map[string]any{
"stream_id": task.StreamID,
"chat_id": task.ChatID,
"response_url": task.ResponseURL != "",
},
)
}
// else: still waiting, return finish=false
}
if finish && !closeStreamOnly {
// Normal finish: remove from all maps.
c.removeTask(task)
} else if closeStreamOnly {
// Mark stream as closed and remove from streamTasks under a single lock
// to keep StreamClosed/StreamClosedAt consistent with map membership.
c.taskMu.Lock()
task.StreamClosed = true
task.StreamClosedAt = time.Now()
delete(c.streamTasks, task.StreamID)
c.taskMu.Unlock()
}
response := WeComAIBotStreamResponse{
MsgType: "stream",
Stream: WeComAIBotStreamInfo{
ID: task.StreamID,
Finish: finish,
Content: content,
},
}
return c.encryptResponse(task.StreamID, timestamp, nonce, response)
}
// removeTask removes a task from both streamTasks and chatTasks, marks it finished,
// and cancels its context to interrupt the associated agent goroutine.
func (c *WeComAIBotChannel) removeTask(task *streamTask) {
// Cancel first so the agent goroutine stops as soon as possible,
// before we acquire the write lock.
task.cancel()
c.taskMu.Lock()
task.Finished = true // written under c.taskMu, consistent with all readers
delete(c.streamTasks, task.StreamID)
queue := c.chatTasks[task.ChatID]
for i, t := range queue {
if t == task {
c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...)
break
}
}
if len(c.chatTasks[task.ChatID]) == 0 {
delete(c.chatTasks, task.ChatID)
}
c.taskMu.Unlock()
}
// sendViaResponseURL posts a markdown reply to the WeCom response_url.
// response_url is valid for 1 hour and can only be used once per callback.
// Returned errors are wrapped with channels.ErrRateLimit, channels.ErrTemporary,
// or channels.ErrSendFailed so the manager can apply the right retry policy.
func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) error {
payload := map[string]any{
"msgtype": "markdown",
"markdown": map[string]string{
"content": content,
},
}
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, responseURL, bytes.NewBuffer(body))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
resp, err := responseURLHTTPClient.Do(req)
if err != nil {
return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
const maxErrBody = 64 << 10 // 64 KB is more than enough for any error response
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
if err != nil {
return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err)
}
switch {
case resp.StatusCode == http.StatusTooManyRequests:
return fmt.Errorf("response_url rate limited (%d): %s: %w",
resp.StatusCode, respBody, channels.ErrRateLimit)
case resp.StatusCode >= 500:
return fmt.Errorf("response_url server error (%d): %s: %w",
resp.StatusCode, respBody, channels.ErrTemporary)
default:
return fmt.Errorf("response_url returned %d: %s: %w",
resp.StatusCode, respBody, channels.ErrSendFailed)
}
}
// encryptResponse encrypts a streaming response
func (c *WeComAIBotChannel) encryptResponse(
streamID, timestamp, nonce string,
response WeComAIBotStreamResponse,
) string {
// Marshal response to JSON
plaintext, err := json.Marshal(response)
if err != nil {
logger.ErrorCF("wecom_aibot", "Failed to marshal response", map[string]any{
"error": err,
})
return ""
}
logger.DebugCF("wecom_aibot", "Encrypting response", map[string]any{
"stream_id": streamID,
"finish": response.Stream.Finish,
"preview": utils.Truncate(response.Stream.Content, 100),
})
// Encrypt message
encrypted, err := c.encryptMessage(string(plaintext), "")
if err != nil {
logger.ErrorCF("wecom_aibot", "Failed to encrypt message", map[string]any{
"error": err,
})
return ""
}
// Generate signature
signature := computeSignature(c.config.Token, timestamp, nonce, encrypted)
// Build encrypted response
encryptedResp := WeComAIBotEncryptedResponse{
Encrypt: encrypted,
MsgSignature: signature,
Timestamp: timestamp,
Nonce: nonce,
}
respJSON, err := json.Marshal(encryptedResp)
if err != nil {
logger.ErrorCF("wecom_aibot", "Failed to marshal encrypted response", map[string]any{
"error": err,
})
return ""
}
logger.DebugCF("wecom_aibot", "Response encrypted", map[string]any{
"stream_id": streamID,
})
return string(respJSON)
}
// encryptEmptyResponse returns a minimal valid encrypted response
func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string {
// Construct a zero-value stream response and encrypt it so that
// WeCom always receives a syntactically valid encrypted JSON object.
emptyResp := WeComAIBotStreamResponse{}
return c.encryptResponse("", timestamp, nonce, emptyResp)
}
// encryptMessage encrypts a plain text message for WeCom AI Bot
func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) {
aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
if err != nil {
return "", err
}
frame, err := packWeComFrame(plaintext, receiveid)
if err != nil {
return "", err
}
// PKCS7 padding then AES-CBC encrypt
paddedFrame := pkcs7Pad(frame, blockSize)
ciphertext, err := encryptAESCBC(aesKey, paddedFrame)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// func (c *WeComAIBotChannel) downloadAndDecryptImage(
// ctx context.Context,
// imageURL string,
// ) ([]byte, error) {
// // Download image
// req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// client := &http.Client{
// Timeout: 15 * time.Second,
// }
// resp, err := client.Do(req)
// if err != nil {
// return nil, fmt.Errorf("failed to download image: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return nil, fmt.Errorf("download failed with status: %d", resp.StatusCode)
// }
// // Limit image download to 20 MB to prevent memory exhaustion
// const maxImageSize = 20 << 20 // 20 MB
// encryptedData, err := io.ReadAll(io.LimitReader(resp.Body, maxImageSize+1))
// if err != nil {
// return nil, fmt.Errorf("failed to read image data: %w", err)
// }
// if len(encryptedData) > maxImageSize {
// return nil, fmt.Errorf("image too large (exceeds %d MB)", maxImageSize>>20)
// }
// logger.DebugCF("wecom_aibot", "Image downloaded", map[string]any{
// "size": len(encryptedData),
// })
// // Decode AES key
// aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
// if err != nil {
// return nil, err
// }
// // Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped)
// decryptedData, err := decryptAESCBC(aesKey, encryptedData)
// if err != nil {
// return nil, fmt.Errorf("failed to decrypt image: %w", err)
// }
// logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{
// "size": len(decryptedData),
// })
// return decryptedData, nil
// }
// generateRandomID generates a cryptographically random alphanumeric ID of
// length n. Used for stream IDs and WebSocket request IDs.
func generateRandomID(n int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, n)
for i := range b {
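num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
// A crypto/rand failure means the system entropy source is unusable;
// fail loudly instead of dereferencing a nil *big.Int below.
panic(fmt.Sprintf("wecom_aibot: crypto/rand failed: %v", err))
}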
b[i] = letters[num.Int64()]
}
return string(b)
}
// generateStreamID generates a random 10-character stream ID (webhook mode).
func (c *WeComAIBotChannel) generateStreamID() string {
return generateRandomID(10)
}
// cleanupLoop periodically cleans up old streaming tasks
func (c *WeComAIBotChannel) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.cleanupOldTasks()
case <-c.ctx.Done():
return
}
}
}
// Task lifetime limits used by cleanupOldTasks.
const (
streamClosedGracePeriod = 10 * time.Minute // max wait for agent after stream closes
taskMaxLifetime = 1 * time.Hour // absolute max (≈ response_url validity)
)
// cleanupOldTasks removes tasks that have exceeded their expected lifetime:
// - Active tasks (in streamTasks): cleaned up after taskMaxLifetime (1 hour, the response_url validity window).
// - StreamClosed tasks (in chatTasks only): cleaned up after streamClosedGracePeriod.
// These tasks are waiting for the agent to call Send() via response_url. If the agent
// crashes or times out without calling Send(), we must not let them accumulate indefinitely.
// The grace period is generous enough to cover typical LLM latency but far shorter than 1 hour,
// preventing chatTasks from filling up when many requests time out in quick succession.
func (c *WeComAIBotChannel) cleanupOldTasks() {
c.taskMu.Lock()
defer c.taskMu.Unlock()
now := time.Now()
cutoff := now.Add(-taskMaxLifetime)
for id, task := range c.streamTasks {
if task.CreatedTime.Before(cutoff) {
delete(c.streamTasks, id)
task.cancel() // interrupt agent goroutine still waiting for LLM
queue := c.chatTasks[task.ChatID]
for i, t := range queue {
if t == task {
c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...)
break
}
}
if len(c.chatTasks[task.ChatID]) == 0 {
delete(c.chatTasks, task.ChatID)
}
logger.DebugCF("wecom_aibot", "Cleaned up expired task", map[string]any{
"stream_id": id,
})
}
}
// Clean up StreamClosed tasks from chatTasks.
// Two expiry conditions are checked:
// 1. Absolute expiry: task was created more than taskMaxLifetime ago.
// 2. Grace expiry: stream closed more than streamClosedGracePeriod ago
// (agent had enough time to reply; it is not coming back).
for chatID, queue := range c.chatTasks {
filtered := queue[:0]
for i, t := range queue {
absoluteExpired := t.CreatedTime.Before(cutoff)
graceExpired := t.StreamClosed &&
!t.StreamClosedAt.IsZero() &&
t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod))
if t.Finished {
// Finished tasks should have been removed by removeTask().
// Finding one here (especially not at position 0) means an
// unexpected code path left it stranded, causing the queue to
// grow silently. Log a warning so it is visible, then drop it.
if i > 0 {
logger.WarnCF("wecom_aibot",
"Found stranded Finished task in the middle of chatTasks queue; "+
"this should not happen — removeTask() should have spliced it out",
map[string]any{
"chat_id": chatID,
"stream_id": t.StreamID,
"position": i,
})
}
// The task is already finished; its context was already canceled
// by removeTask(), so no further action is required.
continue
} else if !absoluteExpired && !graceExpired {
filtered = append(filtered, t)
} else {
t.cancel() // cancel any lingering agent goroutine
}
}
if len(filtered) == 0 {
delete(c.chatTasks, chatID)
} else {
c.chatTasks[chatID] = filtered
}
}
}
// handleHealth handles health check requests
func (c *WeComAIBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
status := "ok"
if !c.IsRunning() {
status = "not running"
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": status,
})
}
================================================
FILE: pkg/channels/wecom/aibot_test.go
================================================
package wecom
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
// ---- Webhook mode tests ----
func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) {
t.Run("success with valid config", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
WebhookPath: "/webhook/test",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if ch == nil {
t.Fatal("Expected channel to be created")
}
if ch.Name() != "wecom_aibot" {
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
}
// Webhook mode must implement WebhookHandler.
if _, ok := ch.(channels.WebhookHandler); !ok {
t.Error("Webhook mode channel should implement WebhookHandler")
}
})
t.Run("error with missing token", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing token, got nil")
}
})
t.Run("error with missing encoding key", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing encoding key, got nil")
}
})
}
func TestWeComAIBotWebhookChannelStartStop(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ctx := context.Background()
if err := ch.Start(ctx); err != nil {
t.Fatalf("Failed to start channel: %v", err)
}
if !ch.IsRunning() {
t.Error("Expected channel to be running after Start")
}
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Failed to stop channel: %v", err)
}
if ch.IsRunning() {
t.Error("Expected channel to be stopped after Stop")
}
}
func TestWeComAIBotChannelWebhookPath(t *testing.T) {
t.Run("default path", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
wh, ok := ch.(channels.WebhookHandler)
if !ok {
t.Fatal("Expected channel to implement WebhookHandler")
}
expectedPath := "/webhook/wecom-aibot"
if wh.WebhookPath() != expectedPath {
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath())
}
})
t.Run("custom path", func(t *testing.T) {
customPath := "/custom/webhook"
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
WebhookPath: customPath,
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
wh, ok := ch.(channels.WebhookHandler)
if !ok {
t.Fatal("Expected channel to implement WebhookHandler")
}
if wh.WebhookPath() != customPath {
t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath())
}
})
}
func TestWeComAIBotChannelGetStreamResponseProcessingMessage(t *testing.T) {
validAESKey := "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG"
t.Run("uses default processing message", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: validAESKey,
}
messageBus := bus.NewMessageBus()
channel, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ch, ok := channel.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
task := &streamTask{
StreamID: "stream-default",
ChatID: "chat-default",
Deadline: time.Now().Add(-time.Second),
}
ch.streamTasks[task.StreamID] = task
ch.chatTasks[task.ChatID] = []*streamTask{task}
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
if !resp.Stream.Finish {
t.Fatal("Expected finished stream response after deadline")
}
if resp.Stream.Content != config.DefaultWeComAIBotProcessingMessage {
t.Fatalf("Expected default processing message %q, got %q",
config.DefaultWeComAIBotProcessingMessage, resp.Stream.Content)
}
if !task.StreamClosed {
t.Fatal("Expected task stream to be marked closed")
}
if _, ok := ch.streamTasks[task.StreamID]; ok {
t.Fatal("Expected closed stream task to be removed from streamTasks")
}
if len(ch.chatTasks[task.ChatID]) != 1 {
t.Fatalf("Expected task to remain queued for response_url delivery, got %d entries",
len(ch.chatTasks[task.ChatID]))
}
})
t.Run("uses custom processing message", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: validAESKey,
ProcessingMessage: "Please wait a moment. The result will be delivered in a follow-up message.",
}
messageBus := bus.NewMessageBus()
channel, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ch, ok := channel.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
task := &streamTask{
StreamID: "stream-custom",
ChatID: "chat-custom",
Deadline: time.Now().Add(-time.Second),
}
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
if resp.Stream.Content != cfg.ProcessingMessage {
t.Fatalf("Expected custom processing message %q, got %q", cfg.ProcessingMessage, resp.Stream.Content)
}
})
}
func TestGenerateStreamID(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
webhookCh, ok := ch.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := webhookCh.generateStreamID()
if len(id) != 10 {
t.Errorf("Expected stream ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate stream ID generated: %s", id)
}
ids[id] = true
}
}
func TestEncryptDecrypt(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
webhookCh, ok := ch.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
plaintext := "Hello, World!"
receiveid := ""
encrypted, err := webhookCh.encryptMessage(plaintext, receiveid)
if err != nil {
t.Fatalf("Failed to encrypt message: %v", err)
}
if encrypted == "" {
t.Fatal("Encrypted message is empty")
}
decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid)
if err != nil {
t.Fatalf("Failed to decrypt message: %v", err)
}
if decrypted != plaintext {
t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
}
}
func TestGenerateSignature(t *testing.T) {
token := "test_token"
timestamp := "1234567890"
nonce := "test_nonce"
encrypt := "encrypted_msg"
signature := computeSignature(token, timestamp, nonce, encrypt)
if signature == "" {
t.Error("Generated signature is empty")
}
if !verifySignature(token, signature, timestamp, nonce, encrypt) {
t.Error("Generated signature does not verify correctly")
}
}
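/*
The round trip checked by TestGenerateSignature can be pictured with the
sketch below. It assumes the callback signature scheme that WeCom documents
for encrypted callbacks (sort the four strings, concatenate, SHA-1, hex); the
actual computeSignature and verifySignature in this package are not shown
here and may differ. sketchSignature and sketchVerify are hypothetical names.

```go
package main

import (
	"crypto/sha1"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
	"sort"
	"strings"
)

// sketchSignature sorts the four inputs lexicographically, joins them
// without a separator, and returns the hex-encoded SHA-1 digest.
func sketchSignature(token, timestamp, nonce, encrypt string) string {
	parts := []string{token, timestamp, nonce, encrypt}
	sort.Strings(parts)
	sum := sha1.Sum([]byte(strings.Join(parts, "")))
	return hex.EncodeToString(sum[:])
}

// sketchVerify recomputes the signature and compares it in constant time.
// The argument order follows the verifySignature call in the test above.
func sketchVerify(token, signature, timestamp, nonce, encrypt string) bool {
	want := sketchSignature(token, timestamp, nonce, encrypt)
	return subtle.ConstantTimeCompare([]byte(want), []byte(signature)) == 1
}

func main() {
	sig := sketchSignature("test_token", "1234567890", "test_nonce", "encrypted_msg")
	fmt.Println(len(sig), sketchVerify("test_token", sig, "1234567890", "test_nonce", "encrypted_msg"))
}
```

Because the inputs are sorted before hashing, the digest does not depend on
the order in which the four values are passed to sketchSignature.
*/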
func decodeStreamResponse(t *testing.T, ch *WeComAIBotChannel, encryptedResponse string) WeComAIBotStreamResponse {
t.Helper()
var wrapped WeComAIBotEncryptedResponse
if err := json.Unmarshal([]byte(encryptedResponse), &wrapped); err != nil {
t.Fatalf("Failed to unmarshal encrypted response: %v", err)
}
plaintext, err := decryptMessageWithVerify(wrapped.Encrypt, ch.config.EncodingAESKey, "")
if err != nil {
t.Fatalf("Failed to decrypt response: %v", err)
}
var resp WeComAIBotStreamResponse
if err := json.Unmarshal([]byte(plaintext), &resp); err != nil {
t.Fatalf("Failed to unmarshal decrypted response: %v", err)
}
return resp
}
// ---- WebSocket long-connection mode tests ----
func TestNewWeComAIBotChannel_WSMode(t *testing.T) {
t.Run("success with bot_id and secret", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if ch == nil {
t.Fatal("Expected channel to be created")
}
if ch.Name() != "wecom_aibot" {
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
}
// WebSocket mode must NOT implement WebhookHandler.
if _, ok := ch.(channels.WebhookHandler); ok {
t.Error("WebSocket mode channel should NOT implement WebhookHandler")
}
})
t.Run("ws mode takes priority over webhook fields", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
Token: "also_set",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if _, ok := ch.(*WeComAIBotWSChannel); !ok {
t.Error("Expected WebSocket mode channel when both BotID+Secret and Token+Key are set")
}
})
t.Run("error with missing bot_id", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Secret: "test_secret",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
// Missing bot_id alone means neither WS mode nor webhook mode is fully configured.
if err == nil {
t.Fatal("Expected error for missing bot_id, got nil")
}
})
t.Run("error with missing secret", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing secret, got nil")
}
})
}
func TestWeComAIBotWSChannelStartStop(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ctx := context.Background()
// Start launches a background goroutine; it should not block or return an error.
if err := ch.Start(ctx); err != nil {
t.Fatalf("Failed to start channel: %v", err)
}
if !ch.IsRunning() {
t.Error("Expected channel to be running after Start")
}
// Stop should work regardless of whether the WebSocket actually connected.
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Failed to stop channel: %v", err)
}
if ch.IsRunning() {
t.Error("Expected channel to be stopped after Stop")
}
}
func TestGenerateRandomID(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 200; i++ {
id := generateRandomID(10)
if len(id) != 10 {
t.Errorf("Expected ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate ID generated: %s", id)
}
ids[id] = true
}
}
func TestWSGenerateID(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 200; i++ {
id := wsGenerateID()
if len(id) != 10 {
t.Errorf("Expected ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate wsGenerateID result: %s", id)
}
ids[id] = true
}
}
// ---- Webhook streaming fallback tests ----
// makeWebhookChannel creates a started WeComAIBotChannel for testing.
func makeWebhookChannel(t *testing.T) *WeComAIBotChannel {
t.Helper()
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG",
}
ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus())
if err != nil {
t.Fatalf("create channel: %v", err)
}
wc, ok := ch.(*WeComAIBotChannel)
if !ok {
t.Fatalf("expected *WeComAIBotChannel (webhook mode), got %T", ch)
}
wc.ctx, wc.cancel = context.WithCancel(context.Background())
return wc
}
// makeStreamTask creates and registers a streamTask for testing.
func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask {
t.Helper()
task := &streamTask{
StreamID: streamID,
ChatID: chatID,
Deadline: deadline,
answerCh: make(chan string, 1),
}
task.ctx, task.cancel = context.WithCancel(ch.ctx)
ch.taskMu.Lock()
ch.streamTasks[streamID] = task
ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task)
ch.taskMu.Unlock()
return task
}
// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already
// placed its answer in answerCh, getStreamResponse returns a finish=true response
// and fully removes the task.
func TestGetStreamResponse_ImmediateAnswer(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second))
task.answerCh <- "hello from agent"
result := ch.getStreamResponse(task, "ts123", "nonce123")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, exists := ch.streamTasks["stream-1"]
ch.taskMu.RUnlock()
if exists {
t.Error("task should have been removed from streamTasks after normal finish")
}
if !task.Finished {
t.Error("task.Finished should be true after normal finish")
}
}
// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has
// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the
// task alive so the response_url fallback can still deliver the answer.
func TestGetStreamResponse_DeadlinePassed(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond))
result := ch.getStreamResponse(task, "ts456", "nonce456")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, stillStreaming := ch.streamTasks["stream-2"]
ch.taskMu.RUnlock()
if stillStreaming {
t.Error("task should have been removed from streamTasks after deadline")
}
if !task.StreamClosed {
t.Error("task.StreamClosed should be true after deadline")
}
if task.Finished {
t.Error("task.Finished must remain false: agent reply still expected via response_url")
}
}
// TestGetStreamResponse_StillPending verifies that when neither the agent has
// replied nor the deadline has passed, getStreamResponse returns without altering
// task state (client should poll again).
func TestGetStreamResponse_StillPending(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second))
result := ch.getStreamResponse(task, "ts789", "nonce789")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, exists := ch.streamTasks["stream-3"]
ch.taskMu.RUnlock()
if !exists {
t.Error("pending task should still be in streamTasks")
}
if task.Finished || task.StreamClosed {
t.Error("pending task should not be finished or stream-closed")
}
// Cleanup.
ch.removeTask(task)
}
================================================
FILE: pkg/channels/wecom/aibot_ws.go
================================================
package wecom
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
)
// Long-connection WebSocket endpoint.
// Ref: https://developer.work.weixin.qq.com/document/path/101463
const (
wsEndpoint = "wss://openws.work.weixin.qq.com"
wsHeartbeatInterval = 30 * time.Second
wsConnectTimeout = 15 * time.Second
wsSubscribeTimeout = 10 * time.Second
wsSendMsgTimeout = 10 * time.Second
wsRespondMsgTimeout = 10 * time.Second
wsWelcomeMsgTimeout = 5 * time.Second // WeCom requires welcome reply within 5 seconds
wsMaxReconnectWait = 60 * time.Second
wsInitialReconnect = time.Second
// WeCom requires finish=true within 6 minutes of the first stream frame.
// wsStreamTickInterval controls how often we send an in-progress hint.
// wsStreamMaxDuration stays 30 seconds below the 6-minute hard limit as a safety margin.
wsStreamTickInterval = 30 * time.Second
wsStreamMaxDuration = 5*time.Minute + 30*time.Second
// wsImageDownloadTimeout caps the time we spend downloading an inbound image.
wsImageDownloadTimeout = 30 * time.Second
// Keep req_id -> chat route for late fallback pushes after stream window closes.
wsLateReplyRouteTTL = 30 * time.Minute
// wsStreamMaxContentBytes is the maximum UTF-8 byte length for the content field
// of a single WeCom AI Bot stream / text / markdown frame.
// Ref: https://developer.work.weixin.qq.com/document/path/101463
wsStreamMaxContentBytes = 20480
)
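/*
A byte limit such as wsStreamMaxContentBytes counts UTF-8 bytes, not
characters, so a naive s[:max] can cut a multi-byte rune in half. The sketch
below shows one way to cut on a rune boundary. truncateUTF8 is a hypothetical
helper for illustration; how this package actually enforces the limit is not
shown here.

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// truncateUTF8 returns the longest prefix of s that is at most maxBytes long
// and does not end in the middle of a multi-byte rune.
func truncateUTF8(s string, maxBytes int) string {
	if maxBytes <= 0 {
		return ""
	}
	if len(s) <= maxBytes {
		return s
	}
	cut := maxBytes
	// s[cut] is the first byte that would be dropped. Walk back until it is
	// the start of a rune, so the kept prefix ends on a rune boundary.
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut]
}

func main() {
	// "é" occupies two bytes, so a 2-byte budget only has room for "h".
	fmt.Println(truncateUTF8("héllo", 2)) // prints "h"
}
```

For valid UTF-8 input the loop steps back at most three bytes, so the cost
is constant regardless of the string length.
*/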
// wsImageHTTPClient is a shared HTTP client for downloading inbound images.
// Reusing it enables connection pooling across multiple image downloads.
var wsImageHTTPClient = &http.Client{Timeout: wsImageDownloadTimeout}
// WeComAIBotWSChannel implements channels.Channel for WeCom AI Bot using the
// WebSocket long-connection API.
// Unlike its webhook counterpart, it does NOT implement WebhookHandler, so the
// HTTP manager will not register any callback URL for it.
type WeComAIBotWSChannel struct {
*channels.BaseChannel
config config.WeComAIBotConfig
ctx context.Context
cancel context.CancelFunc
// conn is the active WebSocket connection; nil when disconnected.
// All writes are serialized through connMu.
conn *websocket.Conn
connMu sync.Mutex
// dedupe prevents duplicate message processing (WeCom may re-deliver).
dedupe *MessageDeduplicator
// reqStates holds per-req_id runtime state.
// It unifies active task state and late-reply fallback routing.
reqStates map[string]*wsReqState
reqStatesMu sync.Mutex
// reqPending correlates command req_ids with response channels.
// Used only for subscribe/ping command-response pairs.
reqPending map[string]chan wsEnvelope
reqPendingMu sync.Mutex
}
// wsTask tracks one in-progress agent reply for a single chat turn.
type wsTask struct {
ReqID string // req_id echoed in all replies for this turn
ChatID string
ChatType uint32
StreamID string // our generated stream.id
answerCh chan string // agent delivers its reply here via Send()
ctx context.Context
cancel context.CancelFunc
}
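// wsReqState is the per-req_id entry stored in reqStates. Task is nil when no
// task is active for the req_id, in which case Send falls back to Route.
type wsReqState struct {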
Task *wsTask
Route wsLateReplyRoute
}
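// wsLateReplyRoute records where to push a reply that arrives when no task is
// active. Send skips the proactive push until ReadyAt has passed.
type wsLateReplyRoute struct {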
ChatID string
ChatType uint32
ReadyAt time.Time
ExpiresAt time.Time
}
// ---- WebSocket protocol types ----
// wsEnvelope is the generic JSON envelope for all WebSocket messages.
type wsEnvelope struct {
Cmd string `json:"cmd,omitempty"`
Headers wsHeaders `json:"headers"`
Body json.RawMessage `json:"body,omitempty"`
ErrCode int `json:"errcode,omitempty"`
ErrMsg string `json:"errmsg,omitempty"`
}
type wsHeaders struct {
ReqID string `json:"req_id"`
}
// wsCommand is an outgoing request sent over the WebSocket.
type wsCommand struct {
Cmd string `json:"cmd"`
Headers wsHeaders `json:"headers"`
Body any `json:"body,omitempty"`
}
type wsSendMsgBody struct {
ChatID string `json:"chatid"`
ChatType uint32 `json:"chat_type,omitempty"`
MsgType string `json:"msgtype"`
Markdown *wsMarkdownContent `json:"markdown,omitempty"`
}
// wsRespondMsgBody is the body for aibot_respond_msg / aibot_respond_welcome_msg.
type wsRespondMsgBody struct {
MsgType string `json:"msgtype"`
Stream *wsStreamContent `json:"stream,omitempty"`
Text *wsTextContent `json:"text,omitempty"`
Markdown *wsMarkdownContent `json:"markdown,omitempty"`
Image *wsImageContent `json:"image,omitempty"`
}
type wsStreamContent struct {
ID string `json:"id"`
Finish bool `json:"finish"`
Content string `json:"content,omitempty"`
}
// wsImageContent carries a base64-encoded image payload for outbound messages.
type wsImageContent struct {
Base64 string `json:"base64"`
MD5 string `json:"md5"`
}
type wsTextContent struct {
Content string `json:"content"`
}
type wsMarkdownContent struct {
Content string `json:"content"`
}
// WeComAIBotWSMessage is the decoded body of aibot_msg_callback /
// aibot_event_callback in WebSocket long-connection mode.
// The structure mirrors WeComAIBotMessage but includes extra fields
// that only appear in long-connection callbacks (Voice, AESKey on Image/File).
type WeComAIBotWSMessage struct {
MsgID string `json:"msgid"`
CreateTime int64 `json:"create_time,omitempty"`
AIBotID string `json:"aibotid"`
ChatID string `json:"chatid,omitempty"`
ChatType string `json:"chattype,omitempty"` // "single" | "group"
From struct {
UserID string `json:"userid"`
} `json:"from"`
MsgType string `json:"msgtype"`
Text *struct {
Content string `json:"content"`
} `json:"text,omitempty"`
Image *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"` // long-connection: per-resource decrypt key
} `json:"image,omitempty"`
Voice *struct {
Content string `json:"content"` // WeCom transcribes voice to text in callbacks
} `json:"voice,omitempty"`
Mixed *struct {
MsgItem []struct {
MsgType string `json:"msgtype"`
Text *struct {
Content string `json:"content"`
} `json:"text,omitempty"`
Image *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"image,omitempty"`
} `json:"msg_item"`
} `json:"mixed,omitempty"`
Event *struct {
EventType string `json:"eventtype"`
} `json:"event,omitempty"`
File *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"file,omitempty"`
Video *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"video,omitempty"`
}
// ---- Constructor ----
// newWeComAIBotWSChannel creates a WeComAIBotWSChannel for WebSocket mode.
func newWeComAIBotWSChannel(
cfg config.WeComAIBotConfig,
messageBus *bus.MessageBus,
) (*WeComAIBotWSChannel, error) {
if cfg.BotID == "" || cfg.Secret == "" {
return nil, fmt.Errorf("bot_id and secret are required for WeCom AI Bot WebSocket mode")
}
base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom,
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
return &WeComAIBotWSChannel{
BaseChannel: base,
config: cfg,
dedupe: NewMessageDeduplicator(wecomMaxProcessedMessages),
reqStates: make(map[string]*wsReqState),
reqPending: make(map[string]chan wsEnvelope),
}, nil
}
// ---- Channel interface ----
// Name implements channels.Channel.
func (c *WeComAIBotWSChannel) Name() string { return "wecom_aibot" }
// Start connects to the WeCom WebSocket endpoint and begins message processing.
func (c *WeComAIBotWSChannel) Start(ctx context.Context) error {
logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel (WebSocket long-connection mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
c.SetRunning(true)
go c.connectLoop()
logger.InfoC("wecom_aibot", "WeCom AI Bot channel started (WebSocket mode)")
return nil
}
// Stop shuts down the channel and closes the WebSocket connection.
func (c *WeComAIBotWSChannel) Stop(_ context.Context) error {
logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel (WebSocket mode)...")
if c.cancel != nil {
c.cancel()
}
c.connMu.Lock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.connMu.Unlock()
c.SetRunning(false)
logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped")
return nil
}
// Send delivers the agent reply for msg.ChatID.
// The waiting task goroutine picks it up and writes the final stream response.
func (c *WeComAIBotWSChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
// msg.ChatID carries the inbound req_id (set by dispatchWSAgentTask).
// For cron-triggered messages, msg.ChatID is the real WeCom chat/user ID
// and there will be no matching entry in reqStates; fall through to proactive push.
task, route, ok := c.getReqState(msg.ChatID)
if !ok {
// No req_id record found — this is a cron/scheduler-originated message.
// Send it as a proactive markdown push using the chat ID directly.
logger.InfoCF("wecom_aibot", "Send: no req_id state, delivering via proactive push (cron/scheduler)",
map[string]any{"chat_id": msg.ChatID})
if err := c.wsSendActivePush(msg.ChatID, 0, msg.Content); err != nil {
logger.WarnCF("wecom_aibot", "Proactive push failed",
map[string]any{"chat_id": msg.ChatID, "error": err.Error()})
return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed)
}
return nil
}
if task == nil {
if time.Now().Before(route.ReadyAt) {
// Keep using aibot_respond_msg within stream window; do not proactively
// push unless wsStreamMaxDuration has elapsed.
logger.WarnCF("wecom_aibot", "Send: stream window still open, skip proactive push",
map[string]any{"req_id": msg.ChatID, "ready_at": route.ReadyAt.Format(time.RFC3339)})
return nil
}
if err := c.wsSendActivePush(route.ChatID, route.ChatType, msg.Content); err != nil {
logger.WarnCF("wecom_aibot", "Late reply proactive push failed",
map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "error": err.Error()})
return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed)
}
logger.InfoCF("wecom_aibot", "Late reply delivered via proactive push",
map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "chat_type": route.ChatType})
c.deleteReqState(msg.ChatID)
return nil
}
// Non-blocking fast path: when answerCh has space, deliver without racing
// against task.ctx.Done() (which fires when the task is canceled by a new
// incoming message, but the response must still be sent).
select {
case task.answerCh <- msg.Content:
return nil
default:
}
// answerCh was full; block with cancellation guards.
select {
case task.answerCh <- msg.Content:
case <-task.ctx.Done():
return nil
case <-ctx.Done():
return ctx.Err()
}
return nil
}
// ---- Connection management ----
// wsBackoffResetDuration is the minimum duration a WebSocket connection must
// stay up before we reset the reconnect backoff to its initial value. This
// prevents a short burst of failures from causing long waits after later,
// stable connection periods.
const wsBackoffResetDuration = time.Minute
// connectLoop maintains the WebSocket connection, reconnecting on failure with
// exponential backoff.
func (c *WeComAIBotWSChannel) connectLoop() {
backoff := wsInitialReconnect
for {
select {
case <-c.ctx.Done():
return
default:
}
logger.InfoC("wecom_aibot", "Connecting to WeCom WebSocket endpoint...")
start := time.Now()
if err := c.runConnection(); err != nil {
elapsed := time.Since(start)
// If the connection was stable for long enough, reset backoff so that
// a previous burst of failures does not keep us at the maximum delay.
if elapsed >= wsBackoffResetDuration {
backoff = wsInitialReconnect
}
select {
case <-c.ctx.Done():
return
default:
logger.WarnCF("wecom_aibot", "WebSocket connection lost, reconnecting",
map[string]any{"error": err.Error(), "backoff": backoff.String()})
select {
case <-time.After(backoff):
case <-c.ctx.Done():
return
}
if backoff < wsMaxReconnectWait {
backoff *= 2
if backoff > wsMaxReconnectWait {
backoff = wsMaxReconnectWait
}
}
}
} else {
// Clean exit (context canceled); stop reconnecting.
return
}
}
}
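/*
The delay schedule produced by connectLoop can be isolated into two small
pure functions, sketched below. The constants mirror wsInitialReconnect,
wsMaxReconnectWait and wsBackoffResetDuration; nextBackoff and delayAfter
are hypothetical names used only for this illustration.

```go
package main

import (
	"fmt"
	"time"
)

const (
	initialWait = time.Second      // mirrors wsInitialReconnect
	maxWait     = 60 * time.Second // mirrors wsMaxReconnectWait
	resetAfter  = time.Minute      // mirrors wsBackoffResetDuration
)

// nextBackoff doubles the current delay and clamps it to maxWait.
func nextBackoff(cur time.Duration) time.Duration {
	if cur >= maxWait {
		return maxWait
	}
	cur *= 2
	if cur > maxWait {
		cur = maxWait
	}
	return cur
}

// delayAfter picks the delay to sleep before the next attempt: a connection
// that stayed up for at least resetAfter starts again from initialWait.
func delayAfter(cur, uptime time.Duration) time.Duration {
	if uptime >= resetAfter {
		return initialWait
	}
	return cur
}

func main() {
	b := initialWait
	for i := 0; i < 8; i++ {
		fmt.Println(b)
		b = nextBackoff(b)
	}
}
```

Starting from one second, the schedule is 1s, 2s, 4s, 8s, 16s, 32s and then
stays at 60s until a connection survives long enough to reset it.
*/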
// runConnection dials, subscribes, and runs the read/heartbeat loops until the
// connection closes or the channel context is canceled.
func (c *WeComAIBotWSChannel) runConnection() error {
dialCtx, dialCancel := context.WithTimeout(c.ctx, wsConnectTimeout)
conn, httpResp, err := websocket.DefaultDialer.DialContext(dialCtx, wsEndpoint, nil)
dialCancel()
if httpResp != nil {
httpResp.Body.Close()
}
if err != nil {
return fmt.Errorf("dial failed: %w", err)
}
c.connMu.Lock()
c.conn = conn
c.connMu.Unlock()
defer func() {
c.connMu.Lock()
if c.conn == conn {
c.conn = nil
}
c.connMu.Unlock()
// Cancel any tasks that were started over this connection so their
// agent goroutines do not keep running after the connection is gone.
c.cancelAllTasks()
}()
// ---- Read loop (must start BEFORE subscribing) ----
// sendAndWait blocks waiting for the subscribe response on reqPending;
// readLoop is the only goroutine that delivers messages to reqPending.
// Starting readLoop first avoids a deadlock where sendAndWait times out
// because no one reads the server's reply.
readErrCh := make(chan error, 1)
go func() { readErrCh <- c.readLoop(conn) }()
// ---- Subscribe ----
reqID := wsGenerateID()
resp, err := c.sendAndWait(conn, reqID, wsCommand{
Cmd: "aibot_subscribe",
Headers: wsHeaders{ReqID: reqID},
Body: map[string]string{
"bot_id": c.config.BotID,
"secret": c.config.Secret,
},
}, wsSubscribeTimeout)
if err != nil {
conn.Close() // stop readLoop
<-readErrCh
return fmt.Errorf("subscribe failed: %w", err)
}
if resp.ErrCode != 0 {
conn.Close()
<-readErrCh
return fmt.Errorf("subscribe rejected (errcode=%d): %s", resp.ErrCode, resp.ErrMsg)
}
logger.InfoC("wecom_aibot", "WebSocket subscription successful")
// ---- Heartbeat goroutine ----
hbDone := make(chan struct{})
go func() {
defer close(hbDone)
c.heartbeatLoop(conn)
}()
// Wait for the read loop to exit, then tear down the heartbeat.
readErr := <-readErrCh
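// Closing conn (idempotent) makes the next heartbeat write fail, which ends
// heartbeatLoop; the wait below can therefore take up to about one
// wsHeartbeatInterval unless the channel context is canceled first.
conn.Close()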
<-hbDone
return readErr
}
// sendAndWait registers a pending-response slot, sends cmd, and blocks until
// the matching response arrives or the timeout/context fires.
func (c *WeComAIBotWSChannel) sendAndWait(
conn *websocket.Conn,
reqID string,
cmd wsCommand,
timeout time.Duration,
) (wsEnvelope, error) {
ch := make(chan wsEnvelope, 1)
c.reqPendingMu.Lock()
c.reqPending[reqID] = ch
c.reqPendingMu.Unlock()
cleanup := func() {
c.reqPendingMu.Lock()
delete(c.reqPending, reqID)
c.reqPendingMu.Unlock()
}
data, err := json.Marshal(cmd)
if err != nil {
cleanup()
return wsEnvelope{}, fmt.Errorf("marshal command: %w", err)
}
c.connMu.Lock()
err = conn.WriteMessage(websocket.TextMessage, data)
c.connMu.Unlock()
if err != nil {
cleanup()
return wsEnvelope{}, fmt.Errorf("write command: %w", err)
}
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case env := <-ch:
return env, nil
case <-timer.C:
cleanup()
return wsEnvelope{}, fmt.Errorf("timeout waiting for response (req_id=%s)", reqID)
case <-c.ctx.Done():
cleanup()
return wsEnvelope{}, c.ctx.Err()
}
}
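/*
The request/response correlation used by sendAndWait and readLoop can be
shown without a socket. In the sketch below, pending, register, deliver and
wait are hypothetical names; string replies stand in for wsEnvelope, and the
network write that sendAndWait performs between registering and waiting is
left out.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

// demoTimeout is an arbitrary short timeout for the demonstration.
const demoTimeout = 20 * time.Millisecond

var errTimeout = errors.New("timeout waiting for response")

// pending maps a request ID to the channel its reply is delivered on.
type pending struct {
	mu    sync.Mutex
	slots map[string]chan string
}

func newPending() *pending {
	return &pending{slots: make(map[string]chan string)}
}

// register reserves a slot before the request is sent, so a fast reply
// cannot arrive before anyone is listening.
func (p *pending) register(id string) chan string {
	ch := make(chan string, 1) // buffered: deliver never blocks the reader
	p.mu.Lock()
	p.slots[id] = ch
	p.mu.Unlock()
	return ch
}

// deliver is the read-loop side: it hands the reply to the waiter, or
// reports false when nobody is waiting (for example after a timeout).
func (p *pending) deliver(id, reply string) bool {
	p.mu.Lock()
	ch, ok := p.slots[id]
	if ok {
		delete(p.slots, id)
	}
	p.mu.Unlock()
	if ok {
		ch <- reply
	}
	return ok
}

// wait blocks until the reply arrives or the timeout fires; on timeout it
// removes the slot so a late reply is dropped.
func (p *pending) wait(id string, ch chan string, timeout time.Duration) (string, error) {
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case reply := <-ch:
		return reply, nil
	case <-timer.C:
		p.mu.Lock()
		delete(p.slots, id)
		p.mu.Unlock()
		return "", errTimeout
	}
}

func main() {
	p := newPending()
	ch := p.register("req-1")
	go p.deliver("req-1", "pong")
	reply, err := p.wait("req-1", ch, demoTimeout)
	fmt.Println(reply, err)
}
```

Deleting the slot under the lock before sending means each channel receives
at most one value, which is why a buffer of one is enough.
*/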
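// heartbeatLoop sends a ping every wsHeartbeatInterval until the channel
// context is canceled or a ping fails. It validates the server's pong via
// sendAndWait; a failed or rejected ping closes the connection, which makes
// readLoop return and triggers a reconnect. Closing conn elsewhere does not
// stop this loop immediately: it exits on the next tick, when the write fails.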
func (c *WeComAIBotWSChannel) heartbeatLoop(conn *websocket.Conn) {
ticker := time.NewTicker(wsHeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
reqID := wsGenerateID()
resp, err := c.sendAndWait(conn, reqID, wsCommand{
Cmd: "ping",
Headers: wsHeaders{ReqID: reqID},
}, wsHeartbeatInterval)
if err != nil {
logger.WarnCF("wecom_aibot", "Heartbeat failed, closing connection",
map[string]any{"error": err.Error()})
conn.Close()
return
}
if resp.ErrCode != 0 {
logger.WarnCF("wecom_aibot", "Heartbeat rejected",
map[string]any{"errcode": resp.ErrCode, "errmsg": resp.ErrMsg})
conn.Close()
return
}
logger.DebugCF("wecom_aibot", "Heartbeat pong received", map[string]any{"req_id": reqID})
case <-c.ctx.Done():
return
}
}
}
// readLoop reads WebSocket messages and dispatches them until the connection
// closes or the channel is stopped.
func (c *WeComAIBotWSChannel) readLoop(conn *websocket.Conn) error {
for {
_, raw, err := conn.ReadMessage()
if err != nil {
select {
case <-c.ctx.Done():
return nil // clean shutdown
default:
return fmt.Errorf("read error: %w", err)
}
}
var env wsEnvelope
if err := json.Unmarshal(raw, &env); err != nil {
logger.WarnCF("wecom_aibot", "Failed to parse WebSocket message",
map[string]any{"error": err.Error(), "raw": string(raw)})
continue
}
// Command responses have an empty Cmd field; forward to any waiting
// sendAndWait() call, or silently drop if no one is waiting (e.g.
// late responses after timeout).
if env.Cmd == "" && env.Headers.ReqID != "" {
c.reqPendingMu.Lock()
ch, ok := c.reqPending[env.Headers.ReqID]
if ok {
delete(c.reqPending, env.Headers.ReqID)
}
c.reqPendingMu.Unlock()
if ok {
ch <- env
}
continue
}
// Dispatch to appropriate handler in a separate goroutine so the
// read loop is never blocked by a slow agent.
go c.handleEnvelope(env)
}
}
// ---- Message / event handlers ----
// handleEnvelope routes a WebSocket envelope to the right handler.
func (c *WeComAIBotWSChannel) handleEnvelope(env wsEnvelope) {
switch env.Cmd {
case "aibot_msg_callback":
c.handleMsgCallback(env)
case "aibot_event_callback":
c.handleEventCallback(env)
default:
logger.DebugCF("wecom_aibot", "Unhandled WebSocket command",
map[string]any{"cmd": env.Cmd})
}
}
// handleMsgCallback processes aibot_msg_callback.
func (c *WeComAIBotWSChannel) handleMsgCallback(env wsEnvelope) {
var msg WeComAIBotWSMessage
if err := json.Unmarshal(env.Body, &msg); err != nil {
logger.WarnCF("wecom_aibot", "Failed to parse msg callback body",
map[string]any{"error": err.Error()})
return
}
// Deduplicate by msgid (WeCom may re-deliver on network issues).
if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) {
logger.DebugCF("wecom_aibot", "Duplicate message ignored",
map[string]any{"msgid": msg.MsgID})
return
}
reqID := env.Headers.ReqID
switch msg.MsgType {
case "text":
c.handleWSTextMessage(reqID, msg)
case "image":
c.handleWSImageMessage(reqID, msg)
case "voice":
c.handleWSVoiceMessage(reqID, msg)
case "mixed":
c.handleWSMixedMessage(reqID, msg)
case "file":
c.handleWSFileMessage(reqID, msg)
case "video":
c.handleWSVideoMessage(reqID, msg)
default:
logger.WarnCF("wecom_aibot", "Unsupported message type",
map[string]any{"msgtype": msg.MsgType})
c.wsSendStreamFinish(reqID, wsGenerateID(),
"Unsupported message type: "+msg.MsgType)
}
}
// handleEventCallback processes aibot_event_callback.
func (c *WeComAIBotWSChannel) handleEventCallback(env wsEnvelope) {
var msg WeComAIBotWSMessage
if err := json.Unmarshal(env.Body, &msg); err != nil {
logger.WarnCF("wecom_aibot", "Failed to parse event callback body",
map[string]any{"error": err.Error()})
return
}
// Deduplicate by msgid.
if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) {
logger.DebugCF("wecom_aibot", "Duplicate event ignored",
map[string]any{"msgid": msg.MsgID})
return
}
var eventType string
if msg.Event != nil {
eventType = msg.Event.EventType
}
logger.DebugCF("wecom_aibot", "Received event callback",
map[string]any{"event_type": eventType})
switch eventType {
case "enter_chat":
if c.config.WelcomeMessage != "" {
c.wsSendWelcomeMsg(env.Headers.ReqID, c.config.WelcomeMessage)
}
case "disconnected_event":
// The server will close this connection after sending this event.
// connectLoop will detect the closure and reconnect automatically.
logger.WarnC("wecom_aibot",
"Received disconnected_event: this connection is being replaced by a newer one")
default:
logger.DebugCF("wecom_aibot", "Unhandled event type",
map[string]any{"event_type": eventType})
}
}
// handleWSTextMessage dispatches a plain-text message to the agent and streams
// the reply back over the WebSocket connection.
func (c *WeComAIBotWSChannel) handleWSTextMessage(reqID string, msg WeComAIBotWSMessage) {
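if msg.Text == nil {
logger.ErrorC("wecom_aibot", "Text message missing text field")
// Close the stream so the user gets feedback, as the other handlers do.
c.wsSendStreamFinish(reqID, wsGenerateID(), "Text message could not be processed.")
return
}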
c.dispatchWSAgentTask(reqID, msg, msg.Text.Content, nil)
}
// handleWSImageMessage downloads and stores the inbound image, then dispatches
// it to the agent as a media-tagged message.
func (c *WeComAIBotWSChannel) handleWSImageMessage(reqID string, msg WeComAIBotWSMessage) {
if msg.Image == nil {
logger.WarnC("wecom_aibot", "Image message missing image field")
c.wsSendStreamFinish(reqID, wsGenerateID(), "Image message could not be processed.")
return
}
c.wsHandleMediaMessage(reqID, msg, msg.Image.URL, msg.Image.AESKey, "image")
}
// wsHandleMediaMessage is a shared helper for image, file and video messages.
// It downloads the resource, stores it in MediaStore, and dispatches to the agent.
func (c *WeComAIBotWSChannel) wsHandleMediaMessage(
reqID string, msg WeComAIBotWSMessage,
resourceURL, aesKey, label string,
) {
chatID := wsChatID(msg)
ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout)
defer cancel()
ref, err := c.storeWSMedia(ctx, chatID, msg.MsgID, resourceURL, aesKey, wsLabelToDefaultExt(label))
if err != nil {
logger.WarnCF("wecom_aibot", "Failed to download/store WS "+label,
map[string]any{"error": err.Error(), "url": resourceURL})
c.wsSendStreamFinish(reqID, wsGenerateID(),
strings.ToUpper(label[:1])+label[1:]+" message could not be processed.")
return
}
c.dispatchWSAgentTask(reqID, msg, "["+label+"]", []string{ref})
}
// handleWSMixedMessage handles mixed text+image messages.
// All text parts are collected into the content string; all image parts are
// downloaded and stored in MediaStore before dispatching to the agent.
func (c *WeComAIBotWSChannel) handleWSMixedMessage(reqID string, msg WeComAIBotWSMessage) {
if msg.Mixed == nil {
logger.WarnC("wecom_aibot", "Mixed message has no content")
c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.")
return
}
chatID := wsChatID(msg)
ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout)
defer cancel()
var textParts []string
var mediaRefs []string
for _, item := range msg.Mixed.MsgItem {
switch item.MsgType {
case "text":
if item.Text != nil && item.Text.Content != "" {
textParts = append(textParts, item.Text.Content)
}
case "image":
if item.Image != nil {
ref, err := c.storeWSMedia(ctx, chatID,
msg.MsgID+"-"+wsGenerateID(), item.Image.URL, item.Image.AESKey, ".jpg")
if err != nil {
logger.WarnCF("wecom_aibot", "Failed to download/store mixed image",
map[string]any{"error": err.Error()})
} else {
mediaRefs = append(mediaRefs, ref)
}
}
default:
logger.WarnCF("wecom_aibot", "Unsupported item type in mixed message",
map[string]any{"msgtype": item.MsgType})
}
}
if len(textParts) == 0 && len(mediaRefs) == 0 {
logger.WarnC("wecom_aibot", "Mixed message has no usable content")
c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.")
return
}
content := strings.Join(textParts, "\n")
if content == "" {
content = "[images]"
}
c.dispatchWSAgentTask(reqID, msg, content, mediaRefs)
}
// dispatchWSAgentTask registers a new agent task, sends the opening stream frame,
// and starts a goroutine that runs the agent and streams the reply back.
// content is the text forwarded to the agent; mediaRefs are optional media
// store references attached to the inbound message.
func (c *WeComAIBotWSChannel) dispatchWSAgentTask(
reqID string,
msg WeComAIBotWSMessage,
content string,
mediaRefs []string,
) {
userID := msg.From.UserID
if userID == "" {
userID = "unknown"
}
// actualChatID is the real WeCom chat/user ID used for peer identification.
// reqID is used as the routing chatID so each turn is independently addressable.
actualChatID := wsChatID(msg)
streamID := wsGenerateID()
chatType := wsChatTypeValue(msg.ChatType)
taskCtx, taskCancel := context.WithCancel(c.ctx)
task := &wsTask{
ReqID: reqID,
ChatID: actualChatID,
ChatType: chatType,
StreamID: streamID,
answerCh: make(chan string, 1),
ctx: taskCtx,
cancel: taskCancel,
}
// Each req_id is unique per WeCom turn, so tasks run concurrently and a new
// turn never cancels an earlier one.
c.setReqState(reqID, &wsReqState{
Task: task,
Route: wsLateReplyRoute{
ChatID: actualChatID,
ChatType: chatType,
ReadyAt: time.Now().Add(wsStreamMaxDuration),
ExpiresAt: time.Now().Add(wsLateReplyRouteTTL),
},
})
logger.DebugCF("wecom_aibot", "Registered new agent task",
map[string]any{"chat_id": actualChatID, "req_id": reqID, "stream_id": streamID})
// Send an empty stream opening frame (finish=false) immediately.
c.wsSendStreamChunk(reqID, streamID, false, "")
go func() {
defer func() {
taskCancel()
c.clearReqTask(reqID, task)
}()
sender := bus.SenderInfo{
Platform: "wecom_aibot",
PlatformID: userID,
CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID),
DisplayName: userID,
}
peerKind := "direct"
if msg.ChatType == "group" {
peerKind = "group"
}
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
metadata := map[string]string{
"channel": "wecom_aibot",
"chat_id": actualChatID,
"chat_type": msg.ChatType,
"msg_type": msg.MsgType,
"msgid": msg.MsgID,
"aibotid": msg.AIBotID,
"stream_id": streamID,
}
// Pass reqID as chatID so that OutboundMessage.ChatID == reqID and Send()
// can look up this turn's request state by reqID.
c.HandleMessage(taskCtx, peer, reqID, userID, reqID,
content, mediaRefs, metadata, sender)
// Wait for the agent reply. While waiting, send periodic finish=false
// hints so the user knows processing is still in progress.
// WeCom requires finish=true within 6 minutes of the first stream frame;
// wsStreamMaxDuration enforces that limit with a safety margin.
waitHints := []string{
"⏳ Processing, please wait...",
"⏳ Still processing, please wait...",
"⏳ Almost there, please wait...",
}
ticker := time.NewTicker(wsStreamTickInterval)
defer ticker.Stop()
deadlineTimer := time.NewTimer(wsStreamMaxDuration)
defer deadlineTimer.Stop()
tickCount := 0
for {
select {
case answer := <-task.answerCh:
// Split the answer into byte-bounded chunks and send as stream frames.
// All but the last carry finish=false; the final frame closes the stream.
chunks := splitWSContent(answer, wsStreamMaxContentBytes)
for i, chunk := range chunks {
c.wsSendStreamChunk(reqID, streamID, i == len(chunks)-1, chunk)
}
c.deleteReqState(reqID)
return
case <-ticker.C:
hint := waitHints[tickCount%len(waitHints)]
tickCount++
logger.DebugCF("wecom_aibot", "Sending stream progress hint",
map[string]any{"chat_id": actualChatID, "tick": tickCount})
c.wsSendStreamChunk(reqID, streamID, false, hint)
case <-deadlineTimer.C:
logger.WarnCF("wecom_aibot",
"Stream response deadline reached, closing stream; late reply will be pushed",
map[string]any{"chat_id": actualChatID})
c.wsSendStreamFinish(reqID, streamID,
"⏳ Processing is taking longer than expected, the response will be sent as a follow-up message.")
return
case <-taskCtx.Done():
// Give a short grace period so that a response queued in the bus
// just before cancellation can still be delivered. This closes a
// race where the task is canceled (for example by cancelAllTasks
// after a connection drop) after the agent already published but
// before Send() wrote to answerCh.
//
// The original connection may be gone at this point, so we cannot
// rely on wsSendStreamFinish. Try wsSendActivePush on the (possibly
// already-restored) connection; if that also fails, leave the
// route intact so Send() can push the reply once reconnected.
select {
case answer := <-task.answerCh:
if err := c.wsSendActivePush(task.ChatID, task.ChatType, answer); err != nil {
logger.WarnCF("wecom_aibot",
"Grace-period push failed after task cancellation; reply may be lost",
map[string]any{"req_id": reqID, "chat_id": task.ChatID, "error": err.Error()})
} else {
c.deleteReqState(reqID)
}
case <-time.After(100 * time.Millisecond):
}
return
}
}
}()
}
// handleWSVoiceMessage handles voice messages.
// WeCom transcribes voice to text in the callback; if the transcription is
// present it is dispatched as plain text to the agent.
func (c *WeComAIBotWSChannel) handleWSVoiceMessage(reqID string, msg WeComAIBotWSMessage) {
if msg.Voice != nil && msg.Voice.Content != "" {
c.dispatchWSAgentTask(reqID, msg, msg.Voice.Content, nil)
return
}
c.wsSendStreamFinish(reqID, wsGenerateID(), "Voice messages are not yet supported.")
}
// handleWSFileMessage handles file messages.
func (c *WeComAIBotWSChannel) handleWSFileMessage(reqID string, msg WeComAIBotWSMessage) {
if msg.File == nil {
logger.WarnC("wecom_aibot", "File message missing file field")
c.wsSendStreamFinish(reqID, wsGenerateID(), "File message could not be processed.")
return
}
c.wsHandleMediaMessage(reqID, msg, msg.File.URL, msg.File.AESKey, "file")
}
// handleWSVideoMessage handles video messages.
func (c *WeComAIBotWSChannel) handleWSVideoMessage(reqID string, msg WeComAIBotWSMessage) {
if msg.Video == nil {
logger.WarnC("wecom_aibot", "Video message missing video field")
c.wsSendStreamFinish(reqID, wsGenerateID(), "Video message could not be processed.")
return
}
c.wsHandleMediaMessage(reqID, msg, msg.Video.URL, msg.Video.AESKey, "video")
}
// ---- WebSocket write helpers ----
// wsSendStreamChunk sends an aibot_respond_msg stream frame.
func (c *WeComAIBotWSChannel) wsSendStreamChunk(reqID, streamID string, finish bool, content string) {
logger.DebugCF("wecom_aibot", "Sending stream chunk", map[string]any{
"stream_id": streamID,
"finish": finish,
"preview": utils.Truncate(content, 100),
})
cmd := wsCommand{
Cmd: "aibot_respond_msg",
Headers: wsHeaders{ReqID: reqID},
Body: wsRespondMsgBody{
MsgType: "stream",
Stream: &wsStreamContent{
ID: streamID,
Finish: finish,
Content: content,
},
},
}
if err := c.writeWSAndWait(cmd, wsRespondMsgTimeout); err != nil {
logger.WarnCF("wecom_aibot", "Stream chunk ack failed", map[string]any{
"req_id": reqID,
"stream_id": streamID,
"finish": finish,
"error": err.Error(),
})
}
}
// wsSendStreamFinish sends the final aibot_respond_msg frame (finish=true, no images).
func (c *WeComAIBotWSChannel) wsSendStreamFinish(reqID, streamID, content string) {
c.wsSendStreamChunk(reqID, streamID, true, content)
}
// wsSendWelcomeMsg sends a text welcome message via aibot_respond_welcome_msg.
func (c *WeComAIBotWSChannel) wsSendWelcomeMsg(reqID, content string) {
logger.DebugCF("wecom_aibot", "Sending welcome message", map[string]any{"req_id": reqID})
cmd := wsCommand{
Cmd: "aibot_respond_welcome_msg",
Headers: wsHeaders{ReqID: reqID},
Body: wsRespondMsgBody{
MsgType: "text",
Text: &wsTextContent{Content: content},
},
}
if err := c.writeWSAndWait(cmd, wsWelcomeMsgTimeout); err != nil {
logger.WarnCF("wecom_aibot", "Welcome message ack failed",
map[string]any{"req_id": reqID, "error": err.Error()})
}
}
// wsSendActivePush sends a proactive markdown message using aibot_send_msg.
// Long content is automatically split into byte-bounded chunks (≤ wsStreamMaxContentBytes
// each) and delivered as consecutive messages.
// It is used as a fallback for late replies after the stream response window expires.
func (c *WeComAIBotWSChannel) wsSendActivePush(chatID string, chatType uint32, content string) error {
if chatID == "" {
return fmt.Errorf("chatid is empty")
}
for _, chunk := range splitWSContent(content, wsStreamMaxContentBytes) {
reqID := wsGenerateID()
if err := c.writeWSAndWait(wsCommand{
Cmd: "aibot_send_msg",
Headers: wsHeaders{ReqID: reqID},
Body: wsSendMsgBody{
ChatID: chatID,
ChatType: chatType,
MsgType: "markdown",
Markdown: &wsMarkdownContent{Content: chunk},
},
}, wsSendMsgTimeout); err != nil {
return err
}
}
return nil
}
// writeWSAndWait writes cmd to the active connection and validates the command response.
func (c *WeComAIBotWSChannel) writeWSAndWait(cmd wsCommand, timeout time.Duration) error {
if cmd.Headers.ReqID == "" {
return fmt.Errorf("req_id is empty")
}
c.connMu.Lock()
conn := c.conn
c.connMu.Unlock()
if conn == nil {
return fmt.Errorf("websocket not connected")
}
resp, err := c.sendAndWait(conn, cmd.Headers.ReqID, cmd, timeout)
if err != nil {
return err
}
if resp.ErrCode != 0 {
return fmt.Errorf("%s rejected (errcode=%d): %s", cmd.Cmd, resp.ErrCode, resp.ErrMsg)
}
return nil
}
// cancelAllTasks cancels every pending agent task; called when the connection drops.
// It also expires each task's stream window (ReadyAt = now) so that when the agent
// eventually delivers its reply via Send(), the message is forwarded via
// wsSendActivePush on the restored connection instead of being silently discarded.
func (c *WeComAIBotWSChannel) cancelAllTasks() {
c.reqStatesMu.Lock()
defer c.reqStatesMu.Unlock()
now := time.Now()
for _, state := range c.reqStates {
if state != nil && state.Task != nil {
state.Task.cancel()
state.Task = nil
// Expire the stream window immediately so Send() uses wsSendActivePush.
state.Route.ReadyAt = now
}
}
}
// setReqState stores state under reqID, pruning nil or expired entries first.
func (c *WeComAIBotWSChannel) setReqState(reqID string, state *wsReqState) {
c.reqStatesMu.Lock()
defer c.reqStatesMu.Unlock()
now := time.Now()
for k, v := range c.reqStates {
if v == nil || now.After(v.Route.ExpiresAt) {
delete(c.reqStates, k)
}
}
c.reqStates[reqID] = state
}
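// getReqState returns the task and late-reply route registered for reqID.
// The bool is false when no state exists or the route has expired; expired
// entries are deleted on access.
func (c *WeComAIBotWSChannel) getReqState(reqID string) (*wsTask, wsLateReplyRoute, bool) {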
c.reqStatesMu.Lock()
defer c.reqStatesMu.Unlock()
state, ok := c.reqStates[reqID]
if !ok || state == nil {
return nil, wsLateReplyRoute{}, false
}
if time.Now().After(state.Route.ExpiresAt) {
delete(c.reqStates, reqID)
return nil, wsLateReplyRoute{}, false
}
return state.Task, state.Route, true
}
// deleteReqState removes the task and late-reply route registered for reqID.
func (c *WeComAIBotWSChannel) deleteReqState(reqID string) {
c.reqStatesMu.Lock()
delete(c.reqStates, reqID)
c.reqStatesMu.Unlock()
}
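// clearReqTask detaches task from the state of reqID if it is still the
// registered task, leaving the late-reply route in place.
func (c *WeComAIBotWSChannel) clearReqTask(reqID string, task *wsTask) {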
c.reqStatesMu.Lock()
defer c.reqStatesMu.Unlock()
state, ok := c.reqStates[reqID]
if !ok || state == nil {
return
}
if state.Task == task {
state.Task = nil
}
}
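// wsChatTypeValue maps the callback chat type string to the numeric chat type
// used in outgoing messages: 2 for "group", 1 for anything else (single chat).
func wsChatTypeValue(chatType string) uint32 {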
if chatType == "group" {
return 2
}
return 1
}
// wsChatID returns the effective chat ID from a WS message.
// For group messages it is msg.ChatID; for single chats it falls back to the sender's UserID.
func wsChatID(msg WeComAIBotWSMessage) string {
if msg.ChatID != "" {
return msg.ChatID
}
return msg.From.UserID
}
// wsGenerateID generates a random 10-character alphanumeric ID.
// It is package-level (not a method) so it can be shared by both channel modes.
func wsGenerateID() string {
return generateRandomID(10)
}
// ---- Inbound media download helpers ----
// storeWSMedia downloads the resource at resourceURL (with optional AES-CBC
// decryption) and stores it in the MediaStore. The file extension is inferred
// from the HTTP Content-Type response header; defaultExt is used as a fallback
// when the content type is absent or unrecognized.
func (c *WeComAIBotWSChannel) storeWSMedia(
ctx context.Context,
chatID, msgID, resourceURL, aesKey, defaultExt string,
) (string, error) {
store := c.GetMediaStore()
if store == nil {
return "", fmt.Errorf("no media store available")
}
const maxSize = 20 << 20 // 20 MB
req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
resp, err := wsImageHTTPClient.Do(req)
if err != nil {
return "", fmt.Errorf("download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("download HTTP %d", resp.StatusCode)
}
// Infer file extension from the Content-Type response header.
ext := wsMediaExtFromContentType(resp.Header.Get("Content-Type"))
if ext == "" {
ext = defaultExt
}
// Buffer the media in memory, bounded to maxSize.
data, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxSize)+1))
if err != nil {
return "", fmt.Errorf("read media: %w", err)
}
if len(data) > maxSize {
return "", fmt.Errorf("media too large (> %d MB)", maxSize>>20)
}
// AES-CBC decryption if a key is present.
if aesKey != "" {
key, decErr := base64.StdEncoding.DecodeString(aesKey)
if decErr != nil || len(key) != 32 {
key, decErr = decodeWeComAESKey(aesKey)
if decErr != nil {
return "", fmt.Errorf("decode media AES key: %w", decErr)
}
}
data, err = decryptAESCBC(key, data)
if err != nil {
return "", fmt.Errorf("decrypt media: %w", err)
}
}
// Write to a temp file. The file is owned by the MediaStore and deleted by
// store.ReleaseAll — no caller-side cleanup needed.
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err = os.MkdirAll(mediaDir, 0o700); err != nil {
return "", fmt.Errorf("mkdir: %w", err)
}
tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext)
if err != nil {
return "", fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
_, writeErr := tmpFile.Write(data)
closeErr := tmpFile.Close()
if writeErr != nil {
os.Remove(tmpPath)
return "", fmt.Errorf("write media: %w", writeErr)
}
if closeErr != nil {
os.Remove(tmpPath)
return "", fmt.Errorf("close media: %w", closeErr)
}
scope := channels.BuildMediaScope("wecom_aibot", chatID, msgID)
ref, err := store.Store(tmpPath, media.MediaMeta{
Filename: msgID + ext,
Source: "wecom_aibot",
}, scope)
if err != nil {
os.Remove(tmpPath)
return "", fmt.Errorf("store: %w", err)
}
return ref, nil
}
// wsMediaExtFromContentType returns the lowercase file extension (with leading
// dot) for the given Content-Type value, or "" when the type is unrecognized.
func wsMediaExtFromContentType(contentType string) string {
if contentType == "" {
return ""
}
// Strip parameters (e.g. "image/jpeg; charset=utf-8" → "image/jpeg").
mt := strings.ToLower(strings.TrimSpace(strings.SplitN(contentType, ";", 2)[0]))
switch mt {
case "image/jpeg", "image/jpg":
return ".jpg"
case "image/png":
return ".png"
case "image/gif":
return ".gif"
case "image/webp":
return ".webp"
case "video/mp4":
return ".mp4"
case "video/mpeg", "video/x-mpeg":
return ".mpeg"
case "video/quicktime":
return ".mov"
case "video/webm":
return ".webm"
case "audio/mpeg", "audio/mp3":
return ".mp3"
case "audio/ogg":
return ".ogg"
case "audio/wav":
return ".wav"
case "application/pdf":
return ".pdf"
case "application/zip":
return ".zip"
case "application/x-rar-compressed", "application/vnd.rar":
return ".rar"
case "text/plain":
return ".txt"
case "application/msword":
return ".doc"
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
return ".docx"
case "application/vnd.ms-excel":
return ".xls"
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
return ".xlsx"
case "application/vnd.ms-powerpoint":
return ".ppt"
case "application/vnd.openxmlformats-officedocument.presentationml.presentation":
return ".pptx"
}
return ""
}
// wsLabelToDefaultExt returns the default file extension for the given media label
// used in wsHandleMediaMessage. It is the fallback when Content-Type detection fails.
func wsLabelToDefaultExt(label string) string {
switch label {
case "image":
return ".jpg"
case "video":
return ".mp4"
default: // "file" and any future labels
return ".bin"
}
}
// ---- Content length helpers ----
// splitWSContent splits content into chunks each fitting within maxBytes UTF-8
// bytes, preserving code block integrity via channels.SplitMessage.
// When SplitMessage still produces an oversized chunk (e.g. dense CJK content),
// splitAtByteBoundary is applied as a last-resort byte-level fallback.
func splitWSContent(content string, maxBytes int) []string {
if len(content) <= maxBytes {
return []string{content}
}
// SplitMessage works in runes. Use maxBytes as the rune limit: for pure ASCII
// this is exact; for multibyte content the byte verification below catches
// any chunk that still overflows.
chunks := channels.SplitMessage(content, maxBytes)
var result []string
for _, chunk := range chunks {
if len(chunk) <= maxBytes {
result = append(result, chunk)
} else {
// Still too large in bytes (e.g. dense CJK); force-split at UTF-8 boundaries.
result = append(result, splitAtByteBoundary(chunk, maxBytes)...)
}
}
return result
}
// splitAtByteBoundary splits s into parts each ≤ maxBytes bytes by walking back
// from the hard byte limit to find a valid UTF-8 rune start boundary.
// This is a last-resort fallback; it does not try to preserve code blocks.
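func splitAtByteBoundary(s string, maxBytes int) []string {
// A non-positive limit can never be satisfied; return s unsplit instead of
// looping forever or slicing with a negative index.
if maxBytes <= 0 {
if s == "" {
return nil
}
return []string{s}
}
var parts []string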
for len(s) > maxBytes {
end := maxBytes
// Walk back past any UTF-8 continuation bytes (high two bits == 10).
for end > 0 && s[end]>>6 == 0b10 {
end--
}
if end == 0 {
end = maxBytes // shouldn't happen with valid UTF-8
}
parts = append(parts, s[:end])
s = strings.TrimLeft(s[end:], " \t\n\r")
}
if s != "" {
parts = append(parts, s)
}
return parts
}
================================================
FILE: pkg/channels/wecom/aibot_ws_test.go
================================================
package wecom
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing.
func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel {
t.Helper()
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
}
ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus())
if err != nil {
t.Fatalf("create WS channel: %v", err)
}
return ch
}
// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no
// MediaStore has been injected.
func TestStoreWSMedia_NilStore(t *testing.T) {
ch := newTestWSChannel(t)
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg")
if err == nil {
t.Fatal("expected error when no MediaStore is set")
}
}
// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors
// from the media server.
func TestStoreWSMedia_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
}))
defer srv.Close()
ch := newTestWSChannel(t)
ch.SetMediaStore(media.NewFileMediaStore())
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
if err == nil {
t.Fatal("expected error for HTTP 404")
}
}
// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear
// error when the media server cannot be reached.
func TestStoreWSMedia_ServerUnavailable(t *testing.T) {
ch := newTestWSChannel(t)
ch.SetMediaStore(media.NewFileMediaStore())
// Port 1 (tcpmux) is almost never listening, so the connection should be refused immediately.
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg")
if err == nil {
t.Fatal("expected error for unreachable server")
}
}
// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded,
// a media ref is returned, and the file persists and is readable via Resolve until
// ReleaseAll is called. The server returns no Content-Type, so the defaultExt is used.
func TestStoreWSMedia_Success_NoAES(t *testing.T) {
imageData := bytes.Repeat([]byte("x"), 256)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageData)
}))
defer srv.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if ref == "" {
t.Fatal("expected non-empty ref")
}
// File must be accessible after storeWSMedia returns (no premature deletion).
path, err := store.Resolve(ref)
if err != nil {
t.Fatalf("ref should resolve: %v", err)
}
got, err := os.ReadFile(path)
if err != nil {
t.Fatalf("file should exist at %s: %v", path, err)
}
if !bytes.Equal(got, imageData) {
t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData))
}
// ReleaseAll must delete the file (store owns lifecycle).
scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1")
if err := store.ReleaseAll(scope); err != nil {
t.Fatalf("ReleaseAll failed: %v", err)
}
if _, err := os.Stat(path); !os.IsNotExist(err) {
t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err)
}
}
// TestStoreWSMedia_MultipleMessages verifies that media messages with different
// msgIDs do not collide and each resolve to a distinct file.
func TestStoreWSMedia_MultipleMessages(t *testing.T) {
imageA := bytes.Repeat([]byte("a"), 64)
imageB := bytes.Repeat([]byte("b"), 64)
srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageA)
}))
defer srvA.Close()
srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageB)
}))
defer srvB.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg")
if err != nil {
t.Fatalf("storeWSMedia A: %v", err)
}
refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg")
if err != nil {
t.Fatalf("storeWSMedia B: %v", err)
}
if refA == refB {
t.Fatal("distinct messages must produce distinct refs")
}
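pathA, err := store.Resolve(refA)
if err != nil {
t.Fatalf("resolve A: %v", err)
}
pathB, err := store.Resolve(refB)
if err != nil {
t.Fatalf("resolve B: %v", err)
}
if pathA == pathB {
t.Fatal("distinct messages must be stored at distinct paths")
}
gotA, err := os.ReadFile(pathA)
if err != nil {
t.Fatalf("read A: %v", err)
}
gotB, err := os.ReadFile(pathB)
if err != nil {
t.Fatalf("read B: %v", err)
}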
if !bytes.Equal(gotA, imageA) {
t.Errorf("content mismatch for message A")
}
if !bytes.Equal(gotB, imageB) {
t.Errorf("content mismatch for message B")
}
}
// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred
// from the HTTP Content-Type header and the defaultExt fallback is used when the
// type is absent or unrecognized.
func TestStoreWSMedia_ContentTypeExt(t *testing.T) {
tests := []struct {
contentType string
wantExt string
}{
{"image/jpeg", ".jpg"},
{"image/png", ".png"},
{"video/mp4", ".mp4"},
{"application/pdf", ".pdf"},
{"application/zip", ".zip"},
// With parameters stripped.
{"video/mp4; codecs=avc1", ".mp4"},
// Absent or unrecognized type → "" (the caller falls back to defaultExt).
{"", ""},
{"application/octet-stream", ""},
}
for _, tc := range tests {
got := wsMediaExtFromContentType(tc.contentType)
if got != tc.wantExt {
t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt)
}
}
// End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin.
// The stored file should carry the .mp4 extension, not .bin.
payload := bytes.Repeat([]byte("v"), 128)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "video/mp4")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(payload)
}))
defer srv.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin")
if err != nil {
t.Fatalf("storeWSMedia: %v", err)
}
path, err := store.Resolve(ref)
if err != nil {
t.Fatalf("resolve: %v", err)
}
if !strings.HasSuffix(path, ".mp4") {
t.Errorf("expected .mp4 extension from Content-Type, got path %q", path)
}
}
// TestSplitWSContent verifies byte-aware splitting of stream content.
func TestSplitWSContent(t *testing.T) {
t.Run("short content is not split", func(t *testing.T) {
chunks := splitWSContent("hello", 20480)
if len(chunks) != 1 || chunks[0] != "hello" {
t.Fatalf("unexpected chunks: %v", chunks)
}
})
t.Run("ASCII content split at byte boundary", func(t *testing.T) {
// Build a string just over the limit.
content := strings.Repeat("a", 20481)
chunks := splitWSContent(content, 20480)
if len(chunks) < 2 {
t.Fatalf("expected >= 2 chunks, got %d", len(chunks))
}
for i, c := range chunks {
if len(c) > 20480 {
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
}
}
// Reassembled content must equal the original (possibly without leading
// whitespace that splitWSContent trims between chunks).
joined := strings.Join(chunks, "")
if len(joined) < len(content)-len(chunks) {
t.Errorf("joined length %d too short (original %d)", len(joined), len(content))
}
})
t.Run("CJK content split within byte limit", func(t *testing.T) {
// Each CJK rune is 3 bytes in UTF-8.
// 7000 CJK chars = 21000 bytes, which exceeds 20480.
content := strings.Repeat("\u4e2d", 7000)
chunks := splitWSContent(content, 20480)
if len(chunks) < 2 {
t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks))
}
for i, c := range chunks {
if len(c) > 20480 {
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
}
// Every chunk must be valid UTF-8 (no torn multi-byte sequences).
if strings.ToValidUTF8(c, "") != c {
t.Errorf("chunk %d contains invalid UTF-8", i)
}
}
})
}
// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter.
func TestSplitAtByteBoundary(t *testing.T) {
t.Run("ASCII fits in one chunk", func(t *testing.T) {
parts := splitAtByteBoundary("hello world", 100)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d", len(parts))
}
})
t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) {
// 10 CJK characters = 30 bytes; split at 20 bytes.
s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes
parts := splitAtByteBoundary(s, 20)
for i, p := range parts {
if len(p) > 20 {
t.Errorf("part %d has %d bytes, want <= 20", i, len(p))
}
// Must be valid UTF-8 (no torn multi-byte sequences).
for j, r := range p {
if r == '\uFFFD' {
t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j)
}
}
}
})
}
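// Illustrative sketch (not part of the package): the tests above pin down the
// contract of splitAtByteBoundary (every part within the byte limit, no rune
// torn in half), but its body is not shown in this file. The hypothetical
// helper splitAtRuneBoundary below is one way such a splitter could be written,
// assuming the input is valid UTF-8. It is a sketch under those assumptions,
// not the project's implementation.

```go
package main

import "unicode/utf8"

// splitAtRuneBoundary (hypothetical name) cuts s into parts of at most
// maxBytes bytes, backing each cut up to the start of a rune so that no
// multi-byte UTF-8 sequence is split across two parts.
func splitAtRuneBoundary(s string, maxBytes int) []string {
	if maxBytes <= 0 || len(s) <= maxBytes {
		return []string{s}
	}
	var parts []string
	for len(s) > maxBytes {
		cut := maxBytes
		// Walk backwards until s[cut] is the first byte of a rune.
		for cut > 0 && !utf8.RuneStart(s[cut]) {
			cut--
		}
		if cut == 0 {
			// A single rune is wider than the limit (or the input is not
			// valid UTF-8): fall back to a hard cut to guarantee progress.
			cut = maxBytes
		}
		parts = append(parts, s[:cut])
		s = s[cut:]
	}
	return append(parts, s)
}
```

// With ten 3-byte CJK characters and a 20-byte limit, the cut at byte 20 falls
// inside the seventh character, so the sketch backs up to byte 18 and yields
// parts of 18 and 12 bytes.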
================================================
FILE: pkg/channels/wecom/app.go
================================================
package wecom
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
const (
wecomAPIBase = "https://qyapi.weixin.qq.com"
)
// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用)
type WeComAppChannel struct {
*channels.BaseChannel
config config.WeComAppConfig
client *http.Client
accessToken string
tokenExpiry time.Time
tokenMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
processedMsgs *MessageDeduplicator
}
// WeComXMLMessage represents the XML message structure from WeCom
type WeComXMLMessage struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
MsgId int64 `xml:"MsgId"`
AgentID int64 `xml:"AgentID"`
PicUrl string `xml:"PicUrl"`
MediaId string `xml:"MediaId"`
Format string `xml:"Format"`
ThumbMediaId string `xml:"ThumbMediaId"`
LocationX float64 `xml:"Location_X"`
LocationY float64 `xml:"Location_Y"`
Scale int `xml:"Scale"`
Label string `xml:"Label"`
Title string `xml:"Title"`
Description string `xml:"Description"`
Url string `xml:"Url"`
Event string `xml:"Event"`
EventKey string `xml:"EventKey"`
}
// WeComTextMessage represents text message for sending
type WeComTextMessage struct {
ToUser string `json:"touser"`
MsgType string `json:"msgtype"`
AgentID int64 `json:"agentid"`
Text struct {
Content string `json:"content"`
} `json:"text"`
Safe int `json:"safe,omitempty"`
}
// WeComMarkdownMessage represents markdown message for sending
type WeComMarkdownMessage struct {
ToUser string `json:"touser"`
MsgType string `json:"msgtype"`
AgentID int64 `json:"agentid"`
Markdown struct {
Content string `json:"content"`
} `json:"markdown"`
}
// WeComImageMessage represents image message for sending
type WeComImageMessage struct {
ToUser string `json:"touser"`
MsgType string `json:"msgtype"`
AgentID int64 `json:"agentid"`
Image struct {
MediaID string `json:"media_id"`
} `json:"image"`
}
// WeComAccessTokenResponse represents the access token API response
type WeComAccessTokenResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}
// WeComSendMessageResponse represents the send message API response
type WeComSendMessageResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
InvalidUser string `json:"invaliduser"`
InvalidParty string `json:"invalidparty"`
InvalidTag string `json:"invalidtag"`
}
// PKCS7Padding adds PKCS7 padding
type PKCS7Padding struct{}
// NewWeComAppChannel creates a new WeCom App channel instance
func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) {
if cfg.CorpID == "" || cfg.CorpSecret == "" || cfg.AgentID == 0 {
return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required")
}
base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom,
channels.WithMaxMessageLength(2048),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
// Client timeout must be >= the configured ReplyTimeout so the
// per-request context deadline is always the effective limit.
clientTimeout := 30 * time.Second
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
clientTimeout = d
}
ctx, cancel := context.WithCancel(context.Background())
return &WeComAppChannel{
BaseChannel: base,
config: cfg,
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
// Name returns the channel name
func (c *WeComAppChannel) Name() string {
return "wecom_app"
}
// Start initializes the WeCom App channel
func (c *WeComAppChannel) Start(ctx context.Context) error {
logger.InfoC("wecom_app", "Starting WeCom App channel...")
// Cancel the context created in the constructor to avoid a resource leak.
if c.cancel != nil {
c.cancel()
}
c.ctx, c.cancel = context.WithCancel(ctx)
// Get initial access token
if err := c.refreshAccessToken(); err != nil {
logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{
"error": err.Error(),
})
}
// Start token refresh goroutine
go c.tokenRefreshLoop()
c.SetRunning(true)
logger.InfoC("wecom_app", "WeCom App channel started")
return nil
}
// Stop gracefully stops the WeCom App channel
func (c *WeComAppChannel) Stop(ctx context.Context) error {
logger.InfoC("wecom_app", "Stopping WeCom App channel...")
if c.cancel != nil {
c.cancel()
}
c.SetRunning(false)
logger.InfoC("wecom_app", "WeCom App channel stopped")
return nil
}
// Send sends a message to WeCom user proactively using access token
func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
accessToken := c.getAccessToken()
if accessToken == "" {
return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary)
}
logger.DebugCF("wecom_app", "Sending message", map[string]any{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content)
}
// SendMedia implements the channels.MediaSender interface.
func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
accessToken := c.getAccessToken()
if accessToken == "" {
return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary)
}
store := c.GetMediaStore()
if store == nil {
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
}
for _, part := range msg.Parts {
localPath, err := store.Resolve(part.Ref)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{
"ref": part.Ref,
"error": err.Error(),
})
continue
}
// Map part type to WeCom media type
var mediaType string
switch part.Type {
case "image":
mediaType = "image"
case "audio":
mediaType = "voice"
case "video":
mediaType = "video"
default:
mediaType = "file"
}
// Upload media to get media_id
mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{
"type": mediaType,
"error": err.Error(),
})
// Fallback: send caption as text
if part.Caption != "" {
_ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption)
}
continue
}
// Send media message using the media_id
if mediaType == "image" {
err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID)
} else {
// For non-image types, send as text fallback with caption
caption := part.Caption
if caption == "" {
caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename)
}
err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption)
}
if err != nil {
return err
}
}
return nil
}
// uploadMedia uploads a local file to WeCom temporary media storage.
func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) {
apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s",
wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType))
file, err := os.Open(localPath)
if err != nil {
return "", fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
filename := filepath.Base(localPath)
formFile, err := writer.CreateFormFile("media", filename)
if err != nil {
return "", fmt.Errorf("failed to create form file: %w", err)
}
if _, err = io.Copy(formFile, file); err != nil {
return "", fmt.Errorf("failed to copy file content: %w", err)
}
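// Close writes the trailing multipart boundary; without it the body is incomplete.
if err = writer.Close(); err != nil {
return "", fmt.Errorf("failed to finalize multipart body: %w", err)
}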
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := c.client.Do(req)
if err != nil {
return "", channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom upload error response: %w", readErr),
)
}
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom upload error: %s", string(respBody)),
)
}
var result struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
MediaID string `json:"media_id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse upload response: %w", err)
}
if result.ErrCode != 0 {
return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
}
return result.MediaID, nil
}
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, url.QueryEscape(accessToken))
jsonData, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
timeout := c.config.ReplyTimeout
if timeout <= 0 {
timeout = 5
}
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom_app error response: %w", readErr),
)
}
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom_app API error: %s", string(respBody)),
)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
var sendResp WeComSendMessageResponse
if err := json.Unmarshal(respBody, &sendResp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if sendResp.ErrCode != 0 {
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
}
return nil
}
// sendImageMessage sends an image message using a media_id.
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
msg := WeComImageMessage{
ToUser: userID,
MsgType: "image",
AgentID: c.config.AgentID,
}
msg.Image.MediaID = mediaID
return c.sendWeComMessage(ctx, accessToken, msg)
}
// WebhookPath returns the path for registering on the shared HTTP server.
func (c *WeComAppChannel) WebhookPath() string {
if c.config.WebhookPath != "" {
return c.config.WebhookPath
}
return "/webhook/wecom-app"
}
// ServeHTTP implements http.Handler for the shared HTTP server.
func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.handleWebhook(w, r)
}
// HealthPath returns the health check endpoint path.
func (c *WeComAppChannel) HealthPath() string {
return "/health/wecom-app"
}
// HealthHandler handles health check requests.
func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
c.handleHealth(w, r)
}
// handleWebhook handles incoming webhook requests from WeCom
func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Log all incoming requests for debugging
logger.DebugCF("wecom_app", "Received webhook request", map[string]any{
"method": r.Method,
"url": r.URL.String(),
"path": r.URL.Path,
"query": r.URL.RawQuery,
})
if r.Method == http.MethodGet {
// Handle verification request
c.handleVerification(ctx, w, r)
return
}
if r.Method == http.MethodPost {
// Handle message callback
c.handleMessageCallback(ctx, w, r)
return
}
logger.WarnCF("wecom_app", "Method not allowed", map[string]any{
"method": r.Method,
})
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handleVerification handles the URL verification request from WeCom
func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
msgSignature := query.Get("msg_signature")
timestamp := query.Get("timestamp")
nonce := query.Get("nonce")
echostr := query.Get("echostr")
logger.DebugCF("wecom_app", "Handling verification request", map[string]any{
"msg_signature": msgSignature,
"timestamp": timestamp,
"nonce": nonce,
"echostr": echostr,
"corp_id": c.config.CorpID,
})
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
logger.ErrorC("wecom_app", "Missing parameters in verification request")
http.Error(w, "Missing parameters", http.StatusBadRequest)
return
}
// Verify signature
if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
// Do not log the configured token: it is a shared secret.
logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{
"msg_signature": msgSignature,
"timestamp": timestamp,
"nonce": nonce,
})
http.Error(w, "Invalid signature", http.StatusForbidden)
return
}
logger.DebugC("wecom_app", "Signature verification passed")
// Decrypt echostr with CorpID verification
// For WeCom App (自建应用), receiveid should be corp_id
// The EncodingAESKey is a secret and is deliberately kept out of the logs.
logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{
"corp_id": c.config.CorpID,
})
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{
"error": err.Error(),
"corp_id": c.config.CorpID,
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{
"decrypted": decryptedEchoStr,
})
// Remove BOM and whitespace as per WeCom documentation
// The response must be plain text without quotes, BOM, or newlines
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
w.Write([]byte(decryptedEchoStr))
}
// handleMessageCallback handles incoming messages from WeCom
func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
msgSignature := query.Get("msg_signature")
timestamp := query.Get("timestamp")
nonce := query.Get("nonce")
if msgSignature == "" || timestamp == "" || nonce == "" {
http.Error(w, "Missing parameters", http.StatusBadRequest)
return
}
// Read request body
defer r.Body.Close()
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read body", http.StatusBadRequest)
return
}
// Parse XML to get encrypted message
var encryptedMsg struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
Encrypt string `xml:"Encrypt"`
AgentID string `xml:"AgentID"`
}
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{
"error": err.Error(),
})
http.Error(w, "Invalid XML", http.StatusBadRequest)
return
}
// Verify signature
if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
logger.WarnC("wecom_app", "Message signature verification failed")
http.Error(w, "Invalid signature", http.StatusForbidden)
return
}
// Decrypt message with CorpID verification
// For WeCom App (自建应用), receiveid should be corp_id
decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{
"error": err.Error(),
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
// Parse decrypted XML message
var msg WeComXMLMessage
if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{
"error": err.Error(),
})
http.Error(w, "Invalid message format", http.StatusBadRequest)
return
}
// Process the message with the channel's long-lived context (not the HTTP
// request context, which is canceled as soon as we return the response).
go c.processMessage(c.ctx, msg)
// Return success response immediately
// WeCom App requires response within configured timeout (default 5 seconds)
w.Write([]byte("success"))
}
// processMessage processes the received message
func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) {
// Only text, image, and voice messages are handled; skip every other type.
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" {
logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{
"msg_type": msg.MsgType,
})
return
}
// Message deduplication: Use msg_id to prevent duplicate processing
// As per WeCom documentation, use msg_id for deduplication
msgID := fmt.Sprintf("%d", msg.MsgId)
if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
senderID := msg.FromUserName
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
// Build metadata
// WeCom App only supports direct messages (private chat)
peer := bus.Peer{Kind: "direct", ID: senderID}
messageID := fmt.Sprintf("%d", msg.MsgId)
metadata := map[string]string{
"msg_type": msg.MsgType,
"msg_id": fmt.Sprintf("%d", msg.MsgId),
"agent_id": fmt.Sprintf("%d", msg.AgentID),
"platform": "wecom_app",
"media_id": msg.MediaId,
"create_time": fmt.Sprintf("%d", msg.CreateTime),
}
content := msg.Content
logger.DebugCF("wecom_app", "Received message", map[string]any{
"sender_id": senderID,
"msg_type": msg.MsgType,
"preview": utils.Truncate(content, 50),
})
// Build sender info
appSender := bus.SenderInfo{
Platform: "wecom",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
}
// Handle the message through the base channel
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender)
}
// tokenRefreshLoop periodically refreshes the access token
func (c *WeComAppChannel) tokenRefreshLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
if err := c.refreshAccessToken(); err != nil {
logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{
"error": err.Error(),
})
}
}
}
}
// refreshAccessToken gets a new access token from WeCom API
func (c *WeComAppChannel) refreshAccessToken() error {
apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret))
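// Use the channel's client and context so the request honors the configured
// timeout and is canceled when the channel stops (http.Get has no timeout).
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, apiURL, nil)
if err != nil {
return fmt.Errorf("failed to create access token request: %w", err)
}
resp, err := c.client.Do(req)
if err != nil {
return fmt.Errorf("failed to request access token: %w", err)
}
defer resp.Body.Close()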
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
var tokenResp WeComAccessTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if tokenResp.ErrCode != 0 {
return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode)
}
c.tokenMu.Lock()
c.accessToken = tokenResp.AccessToken
c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early
c.tokenMu.Unlock()
logger.DebugC("wecom_app", "Access token refreshed successfully")
return nil
}
// getAccessToken returns the current valid access token
func (c *WeComAppChannel) getAccessToken() string {
c.tokenMu.RLock()
defer c.tokenMu.RUnlock()
if time.Now().After(c.tokenExpiry) {
return ""
}
return c.accessToken
}
// sendTextMessage sends a text message to a user.
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
msg := WeComTextMessage{
ToUser: userID,
MsgType: "text",
AgentID: c.config.AgentID,
}
msg.Text.Content = content
return c.sendWeComMessage(ctx, accessToken, msg)
}
// handleHealth handles health check requests
func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
status := map[string]any{
"status": "ok",
"running": c.IsRunning(),
"has_token": c.getAccessToken() != "",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(status)
}
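// Illustrative sketch (not part of the package): handleVerification and
// handleMessageCallback both call verifySignature, which is defined outside
// this file. Going by the test helper generateSignatureApp, the signature is
// the hex SHA-1 of the sorted, concatenated token, timestamp, nonce, and
// encrypted payload. The names computeSignature and checkSignature below are
// hypothetical; the constant-time comparison and the fail-closed empty-token
// rule are assumptions about how such a check could reasonably be written.

```go
package main

import (
	"crypto/sha1"
	"crypto/subtle"
	"encoding/hex"
	"sort"
	"strings"
)

// computeSignature sorts the four inputs lexicographically, concatenates
// them, and returns the lowercase hex SHA-1 digest.
func computeSignature(token, timestamp, nonce, encrypted string) string {
	params := []string{token, timestamp, nonce, encrypted}
	sort.Strings(params)
	sum := sha1.Sum([]byte(strings.Join(params, "")))
	return hex.EncodeToString(sum[:])
}

// checkSignature reports whether signature matches the expected value.
// An empty token is rejected outright so a misconfigured channel fails closed.
func checkSignature(token, signature, timestamp, nonce, encrypted string) bool {
	if token == "" {
		return false
	}
	want := computeSignature(token, timestamp, nonce, encrypted)
	// Constant-time comparison avoids leaking how many leading bytes matched.
	return subtle.ConstantTimeCompare([]byte(want), []byte(signature)) == 1
}
```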
================================================
FILE: pkg/channels/wecom/app_test.go
================================================
package wecom
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
// generateTestAESKeyApp generates a valid test AES key for WeCom App
func generateTestAESKeyApp() string {
// AES key needs to be 32 bytes (256 bits) for AES-256
key := make([]byte, 32)
for i := range key {
key[i] = byte(i + 1)
}
// Return base64 encoded key without padding
return base64.StdEncoding.EncodeToString(key)[:43]
}
// encryptTestMessageApp encrypts a message for testing WeCom App
func encryptTestMessageApp(message, aesKey string) (string, error) {
// Decode AES key
key, err := base64.StdEncoding.DecodeString(aesKey + "=")
if err != nil {
return "", err
}
// Prepare message: random(16) + msg_len(4) + msg + corp_id
random := make([]byte, 0, 16)
for i := range 16 {
random = append(random, byte(i+1))
}
msgBytes := []byte(message)
corpID := []byte("test_corp_id")
msgLen := uint32(len(msgBytes))
lenBytes := make([]byte, 4)
binary.BigEndian.PutUint32(lenBytes, msgLen)
plainText := append(random, lenBytes...)
plainText = append(plainText, msgBytes...)
plainText = append(plainText, corpID...)
// PKCS7 padding
blockSize := aes.BlockSize
padding := blockSize - len(plainText)%blockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
plainText = append(plainText, padText...)
// Encrypt
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize])
cipherText := make([]byte, len(plainText))
mode.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
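// Illustrative sketch (not part of the package): encryptTestMessageApp above
// builds the plaintext as random(16) + msg_len(4, big-endian) + msg + corp_id
// and then applies PKCS7 padding. The hypothetical helper unpackPlaintext
// below shows the reverse step on an already-decrypted buffer. How
// decryptMessageWithVerify actually does this is not shown in this file, and
// the upper bound of 32 on the padding length is an assumption.

```go
package main

import (
	"encoding/binary"
	"errors"
)

// unpackPlaintext strips PKCS7 padding from a decrypted buffer and splits it
// into the message and the trailing receive ID (the corp_id for WeCom App).
func unpackPlaintext(padded []byte) (msg, receiveID string, err error) {
	if len(padded) == 0 {
		return "", "", errors.New("empty plaintext")
	}
	// The last byte gives the number of padding bytes.
	pad := int(padded[len(padded)-1])
	if pad < 1 || pad > 32 || pad > len(padded) {
		return "", "", errors.New("invalid PKCS7 padding")
	}
	plain := padded[:len(padded)-pad]
	// 16 random bytes followed by a 4-byte big-endian length prefix.
	if len(plain) < 20 {
		return "", "", errors.New("plaintext too short")
	}
	n := binary.BigEndian.Uint32(plain[16:20])
	if uint64(n) > uint64(len(plain)-20) {
		return "", "", errors.New("message length exceeds plaintext")
	}
	end := 20 + int(n)
	return string(plain[20:end]), string(plain[end:]), nil
}
```

// A caller would compare the returned receiveID against the configured
// corp_id and reject the message on mismatch.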
// generateSignatureApp generates a signature for testing WeCom App
func generateSignatureApp(token, timestamp, nonce, msgEncrypt string) string {
params := []string{token, timestamp, nonce, msgEncrypt}
sort.Strings(params)
str := strings.Join(params, "")
hash := sha1.Sum([]byte(str))
return fmt.Sprintf("%x", hash)
}
func TestNewWeComAppChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing corp_id", func(t *testing.T) {
cfg := config.WeComAppConfig{
CorpID: "",
CorpSecret: "test_secret",
AgentID: 1000002,
}
_, err := NewWeComAppChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing corp_id, got nil")
}
})
t.Run("missing corp_secret", func(t *testing.T) {
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "",
AgentID: 1000002,
}
_, err := NewWeComAppChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing corp_secret, got nil")
}
})
t.Run("missing agent_id", func(t *testing.T) {
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 0,
}
_, err := NewWeComAppChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing agent_id, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 1000002,
AllowFrom: []string{"user1", "user2"},
}
ch, err := NewWeComAppChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "wecom_app" {
t.Errorf("Name() = %q, want %q", ch.Name(), "wecom_app")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
}
func TestWeComAppChannelIsAllowed(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("empty allowlist allows all", func(t *testing.T) {
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 1000002,
AllowFrom: []string{},
}
ch, _ := NewWeComAppChannel(cfg, msgBus)
if !ch.IsAllowed("any_user") {
t.Error("empty allowlist should allow all users")
}
})
t.Run("allowlist restricts users", func(t *testing.T) {
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 1000002,
AllowFrom: []string{"allowed_user"},
}
ch, _ := NewWeComAppChannel(cfg, msgBus)
if !ch.IsAllowed("allowed_user") {
t.Error("allowed user should pass allowlist check")
}
if ch.IsAllowed("blocked_user") {
t.Error("non-allowed user should be blocked")
}
})
}
func TestWeComAppVerifySignature(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 1000002,
Token: "test_token",
}
ch, _ := NewWeComAppChannel(cfg, msgBus)
t.Run("valid signature", func(t *testing.T) {
timestamp := "1234567890"
nonce := "test_nonce"
msgEncrypt := "test_message"
expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt)
if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
t.Error("valid signature should pass verification")
}
})
t.Run("invalid signature", func(t *testing.T) {
timestamp := "1234567890"
nonce := "test_nonce"
msgEncrypt := "test_message"
if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
t.Error("invalid signature should fail verification")
}
})
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
cfgEmpty := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 1000002,
Token: "",
}
chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus)
if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should reject verification (fail-closed)")
}
})
}
func TestWeComAppDecryptMessage(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("decrypt without AES key", func(t *testing.T) {
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 1000002,
EncodingAESKey: "",
}
ch, _ := NewWeComAppChannel(cfg, msgBus)
// Without AES key, message should be base64 decoded only
plainText := "hello world"
encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
result, err := decryptMessage(encoded, ch.config.EncodingAESKey)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != plainText {
t.Errorf("decryptMessage() = %q, want %q", result, plainText)
}
})
t.Run("decrypt with AES key", func(t *testing.T) {
aesKey := generateTestAESKeyApp()
cfg := config.WeComAppConfig{
CorpID: "test_corp_id",
CorpSecret: "test_secret",
AgentID: 1000002,
EncodingAESKey: aesKey,
}
ch, _ := NewWeComAppChannel(cfg, msgBus)
originalMsg := "Content
`,
),
)
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
ctx := context.Background()
args := map[string]any{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForLLM should contain extracted text (without script/style tags)
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
}
// Should NOT contain script or style tags in ForLLM
if strings.Contains(result.ForLLM, "Keep this
",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, "alert") || strings.Contains(got, "body{}") {
t.Errorf("Expected script/style content removed, got: %q", got)
}
if !strings.Contains(got, "Keep this") {
t.Errorf("Expected 'Keep this' to remain, got: %q", got)
}
},
},
{
name: "collapses excessive blank lines",
input: "A
\n\n\n\n\nB
",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, "\n\n\n") {
t.Errorf("Expected excessive blank lines collapsed, got: %q", got)
}
},
},
{
name: "collapses horizontal whitespace",
input: "hello world
",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, " ") {
t.Errorf("Expected spaces collapsed, got: %q", got)
}
if !strings.Contains(got, "hello world") {
t.Errorf("Expected 'hello world', got: %q", got)
}
},
},
{
name: "empty input",
input: "",
wantFunc: func(t *testing.T, got string) {
if got != "" {
t.Errorf("Expected empty string, got: %q", got)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tool.extractText(tt.input)
tt.wantFunc(t, got)
})
}
}
func withPrivateWebFetchHostsAllowed(t *testing.T) {
t.Helper()
previous := allowPrivateWebFetchHosts.Load()
allowPrivateWebFetchHosts.Store(true)
t.Cleanup(func() {
allowPrivateWebFetchHosts.Store(previous)
})
}
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
t.Helper()
hostPort := strings.TrimPrefix(rawURL, "http://")
hostPort = strings.TrimPrefix(hostPort, "https://")
host, port, err := net.SplitHostPort(hostPort)
if err != nil {
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
}
return host, port
}
func singleHostCIDR(t *testing.T, host string) string {
t.Helper()
ip := net.ParseIP(host)
if ip == nil {
t.Fatalf("failed to parse IP %q", host)
}
if ip.To4() != nil {
return ip.String() + "/32"
}
return ip.String() + "/128"
}
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://127.0.0.1:0",
})
if !result.IsError {
t.Errorf("expected error for private host URL, got success")
}
if !strings.Contains(result.ForLLM, "private or local network") &&
!strings.Contains(result.ForUser, "private or local network") {
t.Errorf("expected private host block message, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("exact whitelist ok"))
}))
defer server.Close()
host, _ := serverHostAndPort(t, server.URL)
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{host})
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
t.Fatalf("expected fetched content, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("cidr whitelist ok"))
}))
defer server.Close()
host, _ := serverHostAndPort(t, server.URL)
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{singleHostCIDR(t, host)})
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
t.Fatalf("expected fetched content, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM)
}
}
// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked
func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[::ffff:127.0.0.1]:0",
})
if !result.IsError {
t.Error("expected error for IPv4-mapped IPv6 loopback URL, got success")
}
}
// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked
func TestWebFetch_BlocksMetadataIP(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://169.254.169.254/latest/meta-data",
})
if !result.IsError {
t.Error("expected error for cloud metadata IP, got success")
}
}
// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked
func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[fd00::1]:0",
})
if !result.IsError {
t.Error("expected error for IPv6 unique local address, got success")
}
}
// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked
func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
// 2002:7f00:0001::1 embeds 127.0.0.1
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[2002:7f00:0001::1]:0",
})
if !result.IsError {
t.Error("expected error for 6to4 with private embedded IPv4, got success")
}
}
// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked
func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
// 2002:0801:0101::1 embeds 8.1.1.1 (public) — pre-flight should pass,
// connection will fail (no listener) but that's after the SSRF check.
result := tool.Execute(context.Background(), map[string]any{
"url": "http://[2002:0801:0101::1]:0",
})
// Should NOT be blocked by SSRF check — error should be connection failure, not "private"
if result.IsError && strings.Contains(result.ForLLM, "private") {
t.Error("6to4 with public embedded IPv4 should not be blocked as private")
}
}
// TestWebFetch_RedirectToPrivateBlocked verifies redirects to private IPs are blocked
func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Redirect to a private IP
http.Redirect(w, r, "http://10.0.0.1/secret", http.StatusFound)
}))
defer server.Close()
// Temporarily disable private host allowance for the redirect check
allowPrivateWebFetchHosts.Store(false)
defer allowPrivateWebFetchHosts.Store(true)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if !result.IsError {
t.Error("expected error when redirecting to private IP, got success")
}
}
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on loopback: %v", err)
}
defer listener.Close()
_, port, err := 
net.SplitHostPort(listener.Addr().String()) if err != nil { t.Fatalf("failed to split listener address: %v", err) } dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil) _, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port)) if err == nil { t.Fatal("expected localhost DNS resolution to be blocked without whitelist") } if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") { t.Fatalf("unexpected error: %v", err) } } func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on loopback: %v", err) } defer listener.Close() accepted := make(chan struct{}, 1) go func() { conn, acceptErr := listener.Accept() if acceptErr != nil { return } conn.Close() accepted <- struct{}{} }() _, port, err := net.SplitHostPort(listener.Addr().String()) if err != nil { t.Fatalf("failed to split listener address: %v", err) } whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"}) if err != nil { t.Fatalf("failed to parse whitelist: %v", err) } dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist) conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port)) if err != nil { t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err) } conn.Close() select { case <-accepted: case <-time.After(time.Second): t.Fatal("expected localhost listener to accept a connection") } } // TestIsPrivateOrRestrictedIP_Table tests IP classification logic func TestIsPrivateOrRestrictedIP_Table(t *testing.T) { tests := []struct { ip string blocked bool desc string }{ {"127.0.0.1", true, "IPv4 loopback"}, {"10.0.0.1", true, "IPv4 private class A"}, {"172.16.0.1", true, "IPv4 private class B"}, {"192.168.1.1", true, "IPv4 private class C"}, {"169.254.169.254", true, "link-local / cloud metadata"}, {"100.64.0.1", true, 
"carrier-grade NAT"}, {"0.0.0.0", true, "unspecified"}, {"8.8.8.8", false, "public DNS"}, {"1.1.1.1", false, "public DNS"}, {"::1", true, "IPv6 loopback"}, {"::ffff:127.0.0.1", true, "IPv4-mapped IPv6 loopback"}, {"::ffff:10.0.0.1", true, "IPv4-mapped IPv6 private"}, {"fc00::1", true, "IPv6 unique local"}, {"fd00::1", true, "IPv6 unique local"}, {"2002:7f00:0001::1", true, "6to4 with embedded 127.x (private)"}, {"2002:0a00:0001::1", true, "6to4 with embedded 10.0.0.1 (private)"}, {"2002:0801:0101::1", false, "6to4 with embedded 8.1.1.1 (public)"}, {"2001:0000:4136:e378:8000:63bf:f5ff:fffe", true, "Teredo with client 10.0.0.1 (private)"}, {"2001:0000:4136:e378:8000:63bf:f7f6:fefe", false, "Teredo with client 8.9.1.1 (public)"}, {"2607:f8b0:4004:800::200e", false, "public IPv6 (Google)"}, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { ip := net.ParseIP(tt.ip) if ip == nil { t.Fatalf("failed to parse IP: %s", tt.ip) } got := isPrivateOrRestrictedIP(ip) if got != tt.blocked { t.Errorf("isPrivateOrRestrictedIP(%s) = %v, want %v", tt.ip, got, tt.blocked) } }) } } // TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain func TestWebTool_WebFetch_MissingDomain(t *testing.T) { tool, err := NewWebFetchTool(50000, format, testFetchLimit) if err != nil { logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) } ctx := context.Background() args := map[string]any{ "url": "https://", } result := tool.Execute(ctx, args) // Should return error result if !result.IsError { t.Errorf("Expected error for URL without domain") } // Should mention missing domain if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") { t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM) } } func TestNewWebFetchToolWithProxy(t *testing.T) { tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil) if err != nil { 
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) } else if tool.maxChars != 1024 { t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024) } if tool.proxy != "http://127.0.0.1:7890" { t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890") } tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil) if err != nil { logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) } if tool.maxChars != 50000 { t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000) } } func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) { _, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, []string{"not-an-ip-or-cidr"}) if err == nil { t.Fatal("expected invalid whitelist entry to fail") } if !strings.Contains(err.Error(), "invalid entry") { t.Fatalf("unexpected error: %v", err) } } func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { t.Run("perplexity", func(t *testing.T) { tool, err := NewWebSearchTool(WebSearchToolOptions{ PerplexityEnabled: true, PerplexityAPIKeys: []string{"k"}, PerplexityMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } p, ok := tool.provider.(*PerplexitySearchProvider) if !ok { t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider) } if p.proxy != "http://127.0.0.1:7890" { t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") } }) t.Run("brave", func(t *testing.T) { tool, err := NewWebSearchTool(WebSearchToolOptions{ BraveEnabled: true, BraveAPIKeys: []string{"k"}, BraveMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } p, ok := tool.provider.(*BraveSearchProvider) if !ok { t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider) } if p.proxy != "http://127.0.0.1:7890" { t.Fatalf("provider proxy 
= %q, want %q", p.proxy, "http://127.0.0.1:7890") } }) t.Run("duckduckgo", func(t *testing.T) { tool, err := NewWebSearchTool(WebSearchToolOptions{ DuckDuckGoEnabled: true, DuckDuckGoMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } p, ok := tool.provider.(*DuckDuckGoSearchProvider) if !ok { t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider) } if p.proxy != "http://127.0.0.1:7890" { t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") } }) } // TestWebTool_TavilySearch_Success verifies successful Tavily search func TestWebTool_TavilySearch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { t.Errorf("Expected POST request, got %s", r.Method) } if r.Header.Get("Content-Type") != "application/json" { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } // Verify payload var payload map[string]any json.NewDecoder(r.Body).Decode(&payload) if payload["api_key"] != "test-key" { t.Errorf("Expected api_key test-key, got %v", payload["api_key"]) } if payload["query"] != "test query" { t.Errorf("Expected query 'test query', got %v", payload["query"]) } // Return mock response response := map[string]any{ "results": []map[string]any{ { "title": "Test Result 1", "url": "https://example.com/1", "content": "Content for result 1", }, { "title": "Test Result 2", "url": "https://example.com/2", "content": "Content for result 2", }, }, } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(response) })) defer server.Close() tool, err := NewWebSearchTool(WebSearchToolOptions{ TavilyEnabled: true, TavilyAPIKeys: []string{"test-key"}, TavilyBaseURL: server.URL, TavilyMaxResults: 5, }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } ctx := context.Background() args := map[string]any{ 
"query": "test query", } result := tool.Execute(ctx, args) // Success should not be an error if result.IsError { t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) } // ForUser should contain result titles and URLs if !strings.Contains(result.ForUser, "Test Result 1") || !strings.Contains(result.ForUser, "https://example.com/1") { t.Errorf("Expected results in output, got: %s", result.ForUser) } // Should mention via Tavily if !strings.Contains(result.ForUser, "via Tavily") { t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser) } } // TestWebFetchTool_CloudflareChallenge_RetryWithHonestUA verifies that a 403 response // with cf-mitigated: challenge triggers a retry using the honest picoclaw User-Agent, // and that the retry response is returned when it succeeds. func TestWebFetchTool_CloudflareChallenge_RetryWithHonestUA(t *testing.T) { withPrivateWebFetchHostsAllowed(t) requestCount := 0 var receivedUAs []string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestCount++ receivedUAs = append(receivedUAs, r.Header.Get("User-Agent")) if requestCount == 1 { // First request: simulate Cloudflare challenge w.Header().Set("Cf-Mitigated", "challenge") w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusForbidden) w.Write([]byte("Cloudflare challenge")) return } // Second request (honest UA retry): success w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte("real content")) })) defer server.Close() tool, err := NewWebFetchTool(50000, format, testFetchLimit) if err != nil { t.Fatalf("NewWebFetchTool() error: %v", err) } result := tool.Execute(context.Background(), map[string]any{"url": server.URL}) if result.IsError { t.Fatalf("expected success after retry, got error: %s", result.ForLLM) } if !strings.Contains(result.ForLLM, "real content") { t.Errorf("expected retry response content, got: %s", result.ForLLM) } if requestCount != 2 { 
		// Fatalf rather than Errorf: the index accesses below would panic with fewer than 2 requests.
		t.Fatalf("expected exactly 2 requests, got %d", requestCount)
	}
	// First request must use the generic user agent
	if receivedUAs[0] != userAgent {
		t.Errorf("first request UA = %q, want %q", receivedUAs[0], userAgent)
	}
	// Second request must use the honest picoclaw user agent
	if !strings.Contains(receivedUAs[1], "picoclaw") {
		t.Errorf("retry request UA = %q, want it to contain 'picoclaw'", receivedUAs[1])
	}
}

// TestWebFetchTool_CloudflareChallenge_NoRetryOnOtherErrors verifies that a plain 403
// (without cf-mitigated: challenge) does NOT trigger a retry.
func TestWebFetchTool_CloudflareChallenge_NoRetryOnOtherErrors(t *testing.T) {
	withPrivateWebFetchHostsAllowed(t)
	requestCount := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestCount++
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusForbidden)
		w.Write([]byte("plain forbidden"))
	}))
	defer server.Close()
	tool, err := NewWebFetchTool(50000, format, testFetchLimit)
	if err != nil {
		t.Fatalf("NewWebFetchTool() error: %v", err)
	}
	tool.Execute(context.Background(), map[string]any{"url": server.URL})
	if requestCount != 1 {
		t.Errorf("expected exactly 1 request for plain 403, got %d", requestCount)
	}
}

// TestWebFetchTool_CloudflareChallenge_RetryFailsToo verifies that if the honest-UA
// retry is also blocked, the retry's 403 response is returned as-is, not as an error.
func TestWebFetchTool_CloudflareChallenge_RetryFailsToo(t *testing.T) { withPrivateWebFetchHostsAllowed(t) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Always return CF challenge regardless of UA w.Header().Set("Cf-Mitigated", "challenge") w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusForbidden) w.Write([]byte("still blocked")) })) defer server.Close() tool, err := NewWebFetchTool(50000, format, testFetchLimit) if err != nil { t.Fatalf("NewWebFetchTool() error: %v", err) } result := tool.Execute(context.Background(), map[string]any{"url": server.URL}) // Should not be an error — the retry response is used as-is (403 is a valid HTTP response) if result.IsError { t.Fatalf("expected non-error result even when retry is also blocked, got: %s", result.ForLLM) } // Status in the JSON result should reflect the 403 if !strings.Contains(result.ForLLM, "403") { t.Errorf("expected status 403 in result, got: %s", result.ForLLM) } } func TestAPIKeyPool(t *testing.T) { pool := NewAPIKeyPool([]string{"key1", "key2", "key3"}) if len(pool.keys) != 3 { t.Fatalf("expected 3 keys, got %d", len(pool.keys)) } if pool.keys[0] != "key1" || pool.keys[1] != "key2" || pool.keys[2] != "key3" { t.Fatalf("unexpected keys: %v", pool.keys) } // Test Iterator: each iterator should cover all keys exactly once iter := pool.NewIterator() expected := []string{"key1", "key2", "key3"} for i, want := range expected { k, ok := iter.Next() if !ok { t.Fatalf("iter.Next() returned false at step %d", i) } if k != want { t.Errorf("step %d: expected %s, got %s", i, want, k) } } // Should be exhausted if _, ok := iter.Next(); ok { t.Errorf("expected iterator exhausted after all keys") } // Second iterator starts at next position (load balancing) iter2 := pool.NewIterator() k, ok := iter2.Next() if !ok { t.Fatal("iter2.Next() returned false") } if k != "key2" { t.Errorf("expected key2 (round-robin), got %s", k) } // Empty pool emptyPool := 
NewAPIKeyPool([]string{}) emptyIter := emptyPool.NewIterator() if _, ok := emptyIter.Next(); ok { t.Errorf("expected false for empty pool") } // Single key pool singlePool := NewAPIKeyPool([]string{"single"}) singleIter := singlePool.NewIterator() if k, ok := singleIter.Next(); !ok || k != "single" { t.Errorf("expected single, got %s (ok=%v)", k, ok) } if _, ok := singleIter.Next(); ok { t.Errorf("expected exhausted after single key") } } func TestWebTool_TavilySearch_Failover(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var payload map[string]any if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { t.Fatalf("failed to decode payload: %v", err) } apiKey := payload["api_key"].(string) if apiKey == "key1" { w.WriteHeader(http.StatusTooManyRequests) w.Write([]byte("Rate limited")) return } if apiKey == "key2" { // Success response := map[string]any{ "results": []map[string]any{ { "title": "Success Result", "url": "https://example.com/success", "content": "Success content", }, }, } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(response) return } w.WriteHeader(http.StatusBadRequest) })) defer server.Close() tool, err := NewWebSearchTool(WebSearchToolOptions{ TavilyEnabled: true, TavilyAPIKeys: []string{"key1", "key2"}, TavilyBaseURL: server.URL, TavilyMaxResults: 5, }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } ctx := context.Background() args := map[string]any{ "query": "test query", } result := tool.Execute(ctx, args) if result.IsError { t.Errorf("Expected success, got Error: %s", result.ForLLM) } if !strings.Contains(result.ForUser, "Success Result") { t.Errorf("Expected failover to second key and success result, got: %s", result.ForUser) } } func TestWebTool_GLMSearch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { 
t.Errorf("Expected POST request, got %s", r.Method) } if r.Header.Get("Content-Type") != "application/json" { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } if r.Header.Get("Authorization") != "Bearer test-glm-key" { t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization")) } var payload map[string]any json.NewDecoder(r.Body).Decode(&payload) if payload["search_query"] != "test query" { t.Errorf("Expected search_query 'test query', got %v", payload["search_query"]) } if payload["search_engine"] != "search_std" { t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"]) } response := map[string]any{ "id": "web-search-test", "created": 1709568000, "search_result": []map[string]any{ { "title": "Test GLM Result", "content": "GLM search snippet", "link": "https://example.com/glm", "media": "Example", "publish_date": "2026-03-04", }, }, } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(response) })) defer server.Close() tool, err := NewWebSearchTool(WebSearchToolOptions{ GLMSearchEnabled: true, GLMSearchAPIKey: "test-glm-key", GLMSearchBaseURL: server.URL, GLMSearchEngine: "search_std", }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } result := tool.Execute(context.Background(), map[string]any{ "query": "test query", }) if result.IsError { t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) } if !strings.Contains(result.ForUser, "Test GLM Result") { t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser) } if !strings.Contains(result.ForUser, "https://example.com/glm") { t.Errorf("Expected URL in output, got: %s", result.ForUser) } if !strings.Contains(result.ForUser, "via GLM Search") { t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser) } } func TestWebTool_GLMSearch_APIError(t *testing.T) { server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`{"error":"invalid api key"}`)) })) defer server.Close() tool, err := NewWebSearchTool(WebSearchToolOptions{ GLMSearchEnabled: true, GLMSearchAPIKey: "bad-key", GLMSearchBaseURL: server.URL, GLMSearchEngine: "search_std", }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } result := tool.Execute(context.Background(), map[string]any{ "query": "test query", }) if !result.IsError { t.Errorf("Expected IsError=true for 401 response") } if !strings.Contains(result.ForLLM, "status 401") { t.Errorf("Expected status 401 in error, got: %s", result.ForLLM) } } func TestWebTool_GLMSearch_Priority(t *testing.T) { // GLM Search should only be selected when all other providers are disabled tool, err := NewWebSearchTool(WebSearchToolOptions{ DuckDuckGoEnabled: true, DuckDuckGoMaxResults: 5, GLMSearchEnabled: true, GLMSearchAPIKey: "test-key", GLMSearchBaseURL: "https://example.com", GLMSearchEngine: "search_std", }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } // DuckDuckGo should win over GLM Search if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok { t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider) } // With DuckDuckGo disabled, GLM Search should be selected tool2, err := NewWebSearchTool(WebSearchToolOptions{ DuckDuckGoEnabled: false, GLMSearchEnabled: true, GLMSearchAPIKey: "test-key", GLMSearchBaseURL: "https://example.com", GLMSearchEngine: "search_std", }) if err != nil { t.Fatalf("NewWebSearchTool() error: %v", err) } if _, ok := tool2.provider.(*GLMSearchProvider); !ok { t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider) } } ================================================ FILE: pkg/utils/bm25.go ================================================ // Package utils provides shared, reusable algorithms. 
// This file implements a generic BM25 search engine.
//
// Usage:
//
//	type MyDoc struct { ID string; Body string }
//
//	corpus := []MyDoc{...}
//	engine := utils.NewBM25Engine(corpus, func(d MyDoc) string {
//		return d.ID + " " + d.Body
//	})
//	results := engine.Search("my query", 5)
package utils

import (
	"math"
	"sort"
	"strings"
)

// ── Tuning defaults ───────────────────────────────────────────────────────────

const (
	// DefaultBM25K1 is the term-frequency saturation factor (typical range 1.2–2.0).
	// Higher values give more weight to repeated terms.
	DefaultBM25K1 = 1.2

	// DefaultBM25B is the document-length normalization factor (0 = none, 1 = full).
	DefaultBM25B = 0.75
)

// BM25Engine is a query-time BM25 search engine over a generic corpus.
// T is the document type; the caller supplies a textFunc that extracts the
// searchable text from each document.
//
// The engine is stateless between queries: no caching, no invalidation logic.
// All indexing work is performed inside Search() on every call, making it
// safe to use on corpora that change frequently.
type BM25Engine[T any] struct {
	corpus   []T
	textFunc func(T) string
	k1       float64
	b        float64
}

// BM25Option is a functional option to configure a BM25Engine.
type BM25Option func(*bm25Config)

type bm25Config struct {
	k1 float64
	b  float64
}

// WithK1 overrides the term-frequency saturation constant (default 1.2).
func WithK1(k1 float64) BM25Option {
	return func(c *bm25Config) { c.k1 = k1 }
}

// WithB overrides the document-length normalization factor (default 0.75).
func WithB(b float64) BM25Option {
	return func(c *bm25Config) { c.b = b }
}

// NewBM25Engine creates a BM25Engine for the given corpus.
//
//   - corpus   : slice of documents of any type T.
//   - textFunc : function that returns the searchable text for a document.
//   - opts     : optional tuning (WithK1, WithB).
//
// The corpus slice is referenced, not copied. Callers must not mutate it
// concurrently with Search().
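//
// Example with custom tuning (the values 1.5 and 0.5 are illustrative only,
// not recommendations taken from this package):
//
//	engine := NewBM25Engine(corpus, textFunc, WithK1(1.5), WithB(0.5))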
func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Option) *BM25Engine[T] {
	cfg := bm25Config{k1: DefaultBM25K1, b: DefaultBM25B}
	for _, o := range opts {
		o(&cfg)
	}
	return &BM25Engine[T]{
		corpus:   corpus,
		textFunc: textFunc,
		k1:       cfg.k1,
		b:        cfg.b,
	}
}

// BM25Result is a single ranked result from a Search call.
type BM25Result[T any] struct {
	Document T
	Score    float32
}

// Search ranks the corpus against query and returns the top-k results.
// Returns an empty slice (not nil) when there are no matches.
//
// Complexity: O(N×L) for indexing + O(|Q|×avgPostingLen) for scoring,
// where N = corpus size, L = average document length, Q = query terms.
// Top-k extraction uses a fixed-size min-heap: O(candidates × log k).
func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
	if topK <= 0 {
		return []BM25Result[T]{}
	}

	queryTerms := bm25Tokenize(query)
	if len(queryTerms) == 0 {
		return []BM25Result[T]{}
	}

	N := len(e.corpus)
	if N == 0 {
		return []BM25Result[T]{}
	}

	// Step 1: build per-document tf + raw doc lengths
	type docEntry struct {
		tf     map[string]uint32
		rawLen int
	}

	entries := make([]docEntry, N)
	df := make(map[string]int, 64)
	totalLen := 0

	for i, doc := range e.corpus {
		tokens := bm25Tokenize(e.textFunc(doc))
		totalLen += len(tokens)

		tf := make(map[string]uint32, len(tokens))
		for _, t := range tokens {
			tf[t]++
		}
		// df: each term counts once per document (iterate the map, keys are unique)
		for t := range tf {
			df[t]++
		}
		entries[i] = docEntry{tf: tf, rawLen: len(tokens)}
	}

	avgDocLen := float64(totalLen) / float64(N)

	// Step 2: pre-compute IDF and per-doc length normalization
	// IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 )
	idf := make(map[string]float32, len(df))
	for term, freq := range df {
		idf[term] = float32(math.Log(
			(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
		))
	}

	// docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen)
	// Stored as float32; sufficient precision for ranking.
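	// Worked example (illustrative numbers, not taken from the source): with the
	// defaults k1 = 1.2 and b = 0.75, a 2-token document in a corpus whose average
	// length is 4 tokens gets 1.2 * (1 - 0.75 + 0.75*2/4) = 1.2 * 0.625 = 0.75.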
	docLenNorm := make([]float32, N)
	for i, entry := range entries {
		docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen))
	}

	// Step 3: build inverted index (posting lists)
	// Iterate the tf map directly: map keys are already unique, no seen-set needed.
	posting := make(map[string][]int32, len(df))
	for i, entry := range entries {
		for term := range entry.tf {
			posting[term] = append(posting[term], int32(i))
		}
	}

	// Step 4: score via posting lists
	// Deduplicate query terms to avoid double-weighting the same term.
	unique := bm25Dedupe(queryTerms)

	scores := make(map[int32]float32)
	for _, term := range unique {
		termIDF, ok := idf[term]
		if !ok {
			continue // term not in vocabulary → zero contribution
		}
		for _, docID := range posting[term] {
			freq := float32(entries[docID].tf[term])
			// TF_norm = freq * (k1+1) / (freq + docLenNorm)
			tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID])
			scores[docID] += termIDF * tfNorm
		}
	}

	if len(scores) == 0 {
		return []BM25Result[T]{}
	}

	// Step 5: top-K via fixed-size min-heap
	// If fewer than topK documents match, the slice is never heapified;
	// the final sort below still puts it in descending score order.
	heap := make([]bm25ScoredDoc, 0, topK)
	for docID, sc := range scores {
		switch {
		case len(heap) < topK:
			heap = append(heap, bm25ScoredDoc{docID: docID, score: sc})
			if len(heap) == topK {
				bm25MinHeapify(heap)
			}
		case sc > heap[0].score:
			heap[0] = bm25ScoredDoc{docID: docID, score: sc}
			bm25SiftDown(heap, 0)
		}
	}

	sort.Slice(heap, func(i, j int) bool {
		return heap[i].score > heap[j].score
	})

	out := make([]BM25Result[T], len(heap))
	for i, h := range heap {
		out[i] = BM25Result[T]{
			Document: e.corpus[h.docID],
			Score:    h.score,
		}
	}
	return out
}

// bm25Tokenize splits s into lowercase tokens, stripping edge punctuation.
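// For example, "Hello, World!" yields ["hello", "world"]; punctuation inside a
// token is kept, so "internal-hyphen" stays a single token.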
func bm25Tokenize(s string) []string {
	raw := strings.Fields(strings.ToLower(s))
	out := raw[:0] // reuse backing array to avoid extra allocation
	for _, t := range raw {
		t = strings.Trim(t, ".,;:!?\"'()/\\-_")
		if t != "" {
			out = append(out, t)
		}
	}
	return out
}

// bm25Dedupe returns a new slice with duplicate tokens removed,
// preserving first-occurrence order.
func bm25Dedupe(tokens []string) []string {
	seen := make(map[string]struct{}, len(tokens))
	out := make([]string, 0, len(tokens))
	for _, t := range tokens {
		if _, ok := seen[t]; !ok {
			seen[t] = struct{}{}
			out = append(out, t)
		}
	}
	return out
}

type bm25ScoredDoc struct {
	docID int32
	score float32
}

// bm25MinHeapify builds a min-heap in-place using Floyd's algorithm: O(k).
func bm25MinHeapify(h []bm25ScoredDoc) {
	for i := len(h)/2 - 1; i >= 0; i-- {
		bm25SiftDown(h, i)
	}
}

// bm25SiftDown restores the min-heap property starting at node i: O(log k).
func bm25SiftDown(h []bm25ScoredDoc, i int) {
	n := len(h)
	for {
		smallest := i
		l, r := 2*i+1, 2*i+2
		if l < n && h[l].score < h[smallest].score {
			smallest = l
		}
		if r < n && h[r].score < h[smallest].score {
			smallest = r
		}
		if smallest == i {
			break
		}
		h[i], h[smallest] = h[smallest], h[i]
		i = smallest
	}
}

================================================
FILE: pkg/utils/bm25_test.go
================================================
package utils

import (
	"reflect"
	"testing"
)

// testDoc is a generic structure for use in tests.
type testDoc struct {
	ID   int
	Text string
}

func extractText(d testDoc) string {
	return d.Text
}

func TestBM25Search_EdgeCases(t *testing.T) {
	corpus := []testDoc{
		{1, "hello world"},
		{2, "foo bar"},
	}
	engine := NewBM25Engine(corpus, extractText)

	tests := []struct {
		name  string
		query string
		topK  int
	}{
		{"Zero topK", "hello", 0},
		{"Negative topK", "hello", -1},
		{"Empty query", "", 5},
		{"Query with only punctuation", "...,,,!!!", 5},
		{"No matches found", "golang", 5},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results := engine.Search(tt.query, tt.topK)
			if len(results) != 0 {
				t.Errorf("expected 0 results, got %d", len(results))
			}
			// Check that it never returns nil, but an empty slice
			if results == nil {
				t.Errorf("expected empty slice, got nil")
			}
		})
	}
}

func TestBM25Search_EmptyCorpus(t *testing.T) {
	engine := NewBM25Engine([]testDoc{}, extractText)
	results := engine.Search("hello", 5)
	if len(results) != 0 || results == nil {
		t.Errorf("expected empty slice from empty corpus, got %v", results)
	}
}

func TestBM25Search_RankingLogic(t *testing.T) {
	corpus := []testDoc{
		{1, "the quick brown fox jumps over the lazy dog"},
		{2, "quick fox"},
		{3, "quick quick quick fox"}, // High Term Frequency (TF)
		{4, "completely irrelevant document here"},
	}
	engine := NewBM25Engine(corpus, extractText)

	t.Run("Term Frequency (TF) boosts score", func(t *testing.T) {
		results := engine.Search("quick", 5)
		if len(results) < 3 {
			t.Fatalf("expected at least 3 results, got %d", len(results))
		}
		// Doc 3 has the word "quick" repeated 3 times, it should beat Doc 2
		if results[0].Document.ID != 3 {
			t.Errorf("expected doc 3 to rank first due to high TF, got doc %d", results[0].Document.ID)
		}
	})

	t.Run("Document Length penalty", func(t *testing.T) {
		results := engine.Search("fox", 5)
		if len(results) < 3 {
			t.Fatalf("expected at least 3 results, got %d", len(results))
		}
		// Doc 2 ("quick fox") is much shorter than Doc 1 ("the quick brown fox..."),
		// so, with equal Term Frequency for the
word "fox" (1 time), Doc 2 wins. if results[0].Document.ID != 2 { t.Errorf("expected doc 2 to rank first due to shorter length, got doc %d", results[0].Document.ID) } }) t.Run("TopK limits results", func(t *testing.T) { results := engine.Search("quick", 2) if len(results) != 2 { t.Errorf("expected exactly 2 results, got %d", len(results)) } }) } func TestBM25Tokenize(t *testing.T) { tests := []struct { input string expected []string }{ {"Hello World", []string{"hello", "world"}}, {" spaces everywhere ", []string{"spaces", "everywhere"}}, {"punctuation... test!!!", []string{"punctuation", "test"}}, {"(parentheses) and-hyphens", []string{"parentheses", "and-hyphens"}}, // hyphens trimmed from edges {"internal-hyphen is kept", []string{"internal-hyphen", "is", "kept"}}, {".,;?!", []string{}}, // Becomes empty after trim } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { got := bm25Tokenize(tt.input) if len(got) == 0 && len(tt.expected) == 0 { return // Both empty } if !reflect.DeepEqual(got, tt.expected) { t.Errorf("bm25Tokenize(%q) = %v, want %v", tt.input, got, tt.expected) } }) } } func TestBM25Dedupe(t *testing.T) { input := []string{"apple", "banana", "apple", "orange", "banana"} expected := []string{"apple", "banana", "orange"} got := bm25Dedupe(input) if !reflect.DeepEqual(got, expected) { t.Errorf("bm25Dedupe() = %v, want %v", got, expected) } } func TestBM25Options(t *testing.T) { corpus := []testDoc{{1, "test"}} engine := NewBM25Engine( corpus, extractText, WithK1(2.5), WithB(0.9), ) if engine.k1 != 2.5 { t.Errorf("expected k1 to be 2.5, got %v", engine.k1) } if engine.b != 0.9 { t.Errorf("expected b to be 0.9, got %v", engine.b) } } func TestBM25Search_SortingStability(t *testing.T) { // Ensure that sorting by heap returns in correct descending order corpus := []testDoc{ {1, "golang is good"}, {2, "golang golang"}, {3, "golang golang golang"}, {4, "golang golang golang golang"}, } engine := NewBM25Engine(corpus, extractText) results := 
		engine.Search("golang", 10)
	if len(results) != 4 {
		t.Fatalf("expected 4 results, got %d", len(results))
	}

	// Scores must be non-increasing (ties are allowed by this check)
	for i := 1; i < len(results); i++ {
		if results[i].Score > results[i-1].Score {
			t.Errorf("results not sorted correctly: result %d score (%v) > result %d score (%v)",
				i, results[i].Score, i-1, results[i-1].Score)
		}
	}
}

================================================
FILE: pkg/utils/download.go
================================================
package utils

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"

	"github.com/sipeed/picoclaw/pkg/logger"
)

// DownloadToFile streams an HTTP response body to a temporary file in small
// chunks (~32KB), keeping peak memory usage constant regardless of file size.
//
// Parameters:
//   - ctx: context for cancellation/timeout
//   - client: HTTP client to use (caller controls timeouts, transport, etc.)
//   - req: fully prepared *http.Request (method, URL, headers, etc.)
//   - maxBytes: maximum bytes to download; 0 means no limit
//
// Returns the path to the temporary file. The caller is responsible for
// removing it when done (defer os.Remove(path)).
//
// On any error the temp file is cleaned up automatically.
func DownloadToFile(ctx context.Context, client *http.Client, req *http.Request, maxBytes int64) (string, error) {
	// Attach context.
	req = req.WithContext(ctx)

	logger.DebugCF("download", "Starting download", map[string]any{
		"url":       req.URL.String(),
		"max_bytes": maxBytes,
	})

	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Read a small amount for the error message.
		errBody := make([]byte, 512)
		n, _ := io.ReadFull(resp.Body, errBody)
		return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
	}

	// Create temp file.
	tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()

	logger.DebugCF("download", "Streaming to temp file", map[string]any{
		"path": tmpPath,
	})

	// Cleanup helper: removes the temp file on any error.
	cleanup := func() {
		_ = tmpFile.Close()
		_ = os.Remove(tmpPath)
	}

	// Optionally limit the download size.
	var src io.Reader = resp.Body
	if maxBytes > 0 {
		src = io.LimitReader(resp.Body, maxBytes+1) // +1 to detect overflow
	}

	written, err := io.Copy(tmpFile, src)
	if err != nil {
		cleanup()
		return "", fmt.Errorf("download write failed: %w", err)
	}

	if maxBytes > 0 && written > maxBytes {
		cleanup()
		return "", fmt.Errorf("download too large: %d bytes (max %d)", written, maxBytes)
	}

	if err := tmpFile.Close(); err != nil {
		_ = os.Remove(tmpPath)
		return "", fmt.Errorf("failed to close temp file: %w", err)
	}

	logger.DebugCF("download", "Download complete", map[string]any{
		"path":          tmpPath,
		"bytes_written": written,
	})

	return tmpPath, nil
}

================================================
FILE: pkg/utils/http_client.go
================================================
package utils

import (
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"
)

// CreateHTTPClient creates an HTTP client with optional proxy support.
// If proxyURL is empty, it uses the system environment proxy settings.
// Supported proxy schemes: http, https, socks5, socks5h.
func CreateHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
	client := &http.Client{
		Timeout: timeout,
		Transport: &http.Transport{
			MaxIdleConns:        10,
			IdleConnTimeout:     30 * time.Second,
			DisableCompression:  false,
			TLSHandshakeTimeout: 15 * time.Second,
		},
	}

	if proxyURL != "" {
		proxy, err := url.Parse(proxyURL)
		if err != nil {
			return nil, fmt.Errorf("invalid proxy URL: %w", err)
		}
		scheme := strings.ToLower(proxy.Scheme)
		switch scheme {
		case "http", "https", "socks5", "socks5h":
		default:
			return nil, fmt.Errorf(
				"unsupported proxy scheme %q (supported: http, https, socks5, socks5h)",
				proxy.Scheme,
			)
		}
		if proxy.Host == "" {
			return nil, fmt.Errorf("invalid proxy URL: missing host")
		}
		client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy)
	} else {
		client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment
	}

	return client, nil
}

================================================
FILE: pkg/utils/http_client_test.go
================================================
package utils

import (
	"net/http"
	"strings"
	"testing"
	"time"
)

func TestCreateHTTPClient_ProxyConfigured(t *testing.T) {
	client, err := CreateHTTPClient("http://127.0.0.1:7890", 12*time.Second)
	if err != nil {
		t.Fatalf("CreateHTTPClient() error: %v", err)
	}
	if client.Timeout != 12*time.Second {
		t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second)
	}

	tr, ok := client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
	}
	if tr.Proxy == nil {
		t.Fatal("transport.Proxy is nil, want non-nil")
	}

	req, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}
	proxyURL, err := tr.Proxy(req)
	if err != nil {
		t.Fatalf("transport.Proxy(req) error: %v", err)
	}
	if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
		t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
	}
}

func TestCreateHTTPClient_InvalidProxy(t *testing.T) {
	_, err := CreateHTTPClient("://bad-proxy", 10*time.Second)
	if err == nil {
		t.Fatal("CreateHTTPClient() expected error for invalid proxy URL, got nil")
	}
}

func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) {
	client, err := CreateHTTPClient("socks5://127.0.0.1:1080", 8*time.Second)
	if err != nil {
		t.Fatalf("CreateHTTPClient() error: %v", err)
	}

	tr, ok := client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
	}

	req, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}
	proxyURL, err := tr.Proxy(req)
	if err != nil {
		t.Fatalf("transport.Proxy(req) error: %v", err)
	}
	if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" {
		t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080")
	}
}

func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) {
	_, err := CreateHTTPClient("ftp://127.0.0.1:21", 10*time.Second)
	if err == nil {
		t.Fatal("CreateHTTPClient() expected error for unsupported scheme, got nil")
	}
	if !strings.Contains(err.Error(), "unsupported proxy scheme") {
		t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme")
	}
}

func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
	t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
	t.Setenv("http_proxy", "http://127.0.0.1:8888")
	t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
	t.Setenv("https_proxy", "http://127.0.0.1:8888")
	t.Setenv("ALL_PROXY", "")
	t.Setenv("all_proxy", "")
	t.Setenv("NO_PROXY", "")
	t.Setenv("no_proxy", "")

	client, err := CreateHTTPClient("", 10*time.Second)
	if err != nil {
		t.Fatalf("CreateHTTPClient() error: %v", err)
	}

	tr, ok := client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
	}
	if tr.Proxy == nil {
		t.Fatal("transport.Proxy is nil, want proxy function from environment")
	}

	req, err :=
		http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}
	if _, err := tr.Proxy(req); err != nil {
		t.Fatalf("transport.Proxy(req) error: %v", err)
	}
}

================================================
FILE: pkg/utils/http_retry.go
================================================
package utils

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

const maxRetries = 3

var retryDelayUnit = time.Second

// shouldRetry reports whether a response status warrants another attempt:
// HTTP 429 or any status of 500 and above.
func shouldRetry(statusCode int) bool {
	return statusCode == http.StatusTooManyRequests || statusCode >= 500
}

// DoRequestWithRetry sends req up to maxRetries times, retrying on transport
// errors and on retryable statuses, waiting retryDelayUnit*attempt between
// attempts. It returns the last response and error; the caller must close the
// returned response body. If the request context is done during a wait, any
// pending response body is closed and an error is returned.
func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response, error) {
	var resp *http.Response
	var err error

	for i := range maxRetries {
		// Release the previous attempt's body before retrying.
		if i > 0 && resp != nil {
			resp.Body.Close()
		}

		resp, err = client.Do(req)
		if err == nil {
			if resp.StatusCode == http.StatusOK {
				break
			}
			if !shouldRetry(resp.StatusCode) {
				break
			}
		}

		if i < maxRetries-1 {
			if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil {
				if resp != nil {
					resp.Body.Close()
				}
				return nil, fmt.Errorf("failed to sleep: %w", err)
			}
		}
	}
	return resp, err
}

// sleepWithCtx waits for d or until ctx is done, whichever comes first.
func sleepWithCtx(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

================================================
FILE: pkg/utils/http_retry_test.go
================================================
package utils

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestDoRequestWithRetry(t *testing.T) {
	retryDelayUnit = time.Millisecond
	t.Cleanup(func() { retryDelayUnit = time.Second })

	testcases := []struct {
		name           string
		serverBehavior func(*httptest.Server) int
		wantSuccess    bool
		wantAttempts   int
	}{
		{
			name: "success-on-first-attempt",
			serverBehavior: func(server *httptest.Server) int {
				return 0
			},
			wantSuccess:  true,
			wantAttempts: 1,
		},
		{
			name:
				"fail-all-attempts",
			serverBehavior: func(server *httptest.Server) int {
				return 4
			},
			wantSuccess:  false,
			wantAttempts: 3,
		},
	}

	for _, tc := range testcases {
		t.Run(tc.name, func(t *testing.T) {
			attempts := 0
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				attempts++
				if attempts <= tc.serverBehavior(nil) {
					w.WriteHeader(http.StatusInternalServerError)
					return
				}
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("success"))
			}))
			t.Cleanup(func() { server.Close() })

			client := &http.Client{Timeout: 5 * time.Second}
			req, err := http.NewRequest(http.MethodGet, server.URL, nil)
			require.NoError(t, err)

			resp, err := DoRequestWithRetry(client, req)
			if tc.wantSuccess {
				require.NoError(t, err)
				require.NotNil(t, resp)
				assert.Equal(t, http.StatusOK, resp.StatusCode)
				resp.Body.Close()
			} else {
				require.NotNil(t, resp)
				assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
				resp.Body.Close()
			}
			assert.Equal(t, tc.wantAttempts, attempts)
		})
	}
}

func TestDoRequestWithRetry_ContextCancel(t *testing.T) {
	// Use a long retry delay so cancellation always hits during sleepWithCtx.
	retryDelayUnit = 10 * time.Second
	t.Cleanup(func() { retryDelayUnit = time.Second })

	bodyClosed := false
	firstRoundTripDone := make(chan struct{}, 1)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("error"))
	}))
	defer server.Close()

	client := server.Client()
	client.Timeout = 30 * time.Second
	client.Transport = &bodyCloseTracker{
		rt:      client.Transport,
		onClose: func() { bodyClosed = true },
		// Signal after the first round-trip response is fully constructed on the client side.
		onRoundTrip: func() {
			select {
			case firstRoundTripDone <- struct{}{}:
			default:
			}
		},
		trackURL: server.URL,
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Cancel the context after the first round-trip completes on the client side.
	// This ensures client.Do has returned a valid resp (with body) and the retry
	// loop is about to enter sleepWithCtx, where the cancel will be detected.
	go func() {
		<-firstRoundTripDone
		cancel()
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
	require.NoError(t, err)

	resp, err := DoRequestWithRetry(client, req)
	if resp != nil {
		resp.Body.Close()
	}

	require.Error(t, err, "expected error from context cancellation")
	assert.Nil(t, resp, "expected nil response when context is canceled")
	assert.True(t, bodyClosed, "expected resp.Body to be closed on context cancellation")
}

// bodyCloseTracker wraps an http.RoundTripper and records when response bodies are closed.
type bodyCloseTracker struct {
	rt          http.RoundTripper
	onClose     func()
	onRoundTrip func() // called after each successful round-trip
	trackURL    string
}

func (t *bodyCloseTracker) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := t.rt.RoundTrip(req)
	if err != nil {
		return resp, err
	}
	if strings.HasPrefix(req.URL.String(), t.trackURL) {
		resp.Body = &closeNotifier{ReadCloser: resp.Body, onClose: t.onClose}
		if t.onRoundTrip != nil {
			t.onRoundTrip()
		}
	}
	return resp, nil
}

// closeNotifier wraps an io.ReadCloser to detect Close calls.
type closeNotifier struct {
	io.ReadCloser
	onClose func()
}

func (c *closeNotifier) Close() error {
	c.onClose()
	return c.ReadCloser.Close()
}

func TestDoRequestWithRetry_Delay(t *testing.T) {
	retryDelayUnit = time.Millisecond
	t.Cleanup(func() { retryDelayUnit = time.Second })

	var start time.Time
	delays := []time.Duration{}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if len(delays) == 0 {
			delays = append(delays, 0)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if len(delays) == 1 {
			start = time.Now()
			delays = append(delays, 0)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if len(delays) == 2 {
			elapsed := time.Since(start)
			delays = append(delays, elapsed)
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("success"))
		}
	}))
	defer server.Close()

	client := &http.Client{Timeout: 10 * time.Second}
	req, err := http.NewRequest(http.MethodGet, server.URL, nil)
	require.NoError(t, err)

	resp, err := DoRequestWithRetry(client, req)
	require.NoError(t, err)
	require.NotNil(t, resp)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	resp.Body.Close()

	assert.GreaterOrEqual(t, delays[2], time.Millisecond)
}

================================================
FILE: pkg/utils/markdown.go
================================================
package utils

import (
	"bytes"
	"net/url"
	"regexp"
	"strconv"
	"strings"

	"golang.org/x/net/html"
)

var (
	reSpaces           = regexp.MustCompile(`[ \t]+`)
	reNewlines         = regexp.MustCompile(`\n{3,}`)
	reEmptyListItem    = regexp.MustCompile(`(?m)^[-*]\s*$`)
	reImageOnlyLink    = regexp.MustCompile(`\[!\[\]\(<[^>]*>\)\]\(<[^>]*>\)`)
	reEmptyHeader      = regexp.MustCompile(`(?m)^#{1,6}\s*$`)
	reLeadingLineSpace = regexp.MustCompile(`(?m)^([ \t])([^ \t\n])`)
)

var skipTags = map[string]bool{
	"script":   true,
	"style":    true,
	"head":     true,
	"noscript": true,
	"template": true,
	"nav":      true,
	"footer":   true,
	"aside":    true,
	"header":   true,
	"form":     true,
	"dialog":   true,
}

func isSafeHref(href string) bool {
	lower :=
		strings.ToLower(strings.TrimSpace(href))
	if strings.HasPrefix(lower, "javascript:") ||
		strings.HasPrefix(lower, "vbscript:") ||
		strings.HasPrefix(lower, "data:") {
		return false
	}
	u, err := url.Parse(strings.TrimSpace(href))
	if err != nil {
		return false
	}
	scheme := strings.ToLower(u.Scheme)
	return scheme == "" || scheme == "http" || scheme == "https" || scheme == "mailto"
}

func isSafeImageSrc(src string) bool {
	lower := strings.ToLower(strings.TrimSpace(src))
	if strings.HasPrefix(lower, "data:image/") {
		return true
	}
	return isSafeHref(src)
}

func escapeMdAlt(s string) string {
	s = strings.ReplaceAll(s, `\`, `\\`)
	s = strings.ReplaceAll(s, `[`, `\[`)
	s = strings.ReplaceAll(s, `]`, `\]`)
	return s
}

func getAttr(n *html.Node, key string) string {
	for _, a := range n.Attr {
		if a.Key == key {
			return a.Val
		}
	}
	return ""
}

func normalizeAttr(val string) string {
	val = strings.ReplaceAll(val, "\n", "")
	val = strings.ReplaceAll(val, "\r", "")
	val = strings.ReplaceAll(val, "\t", "")
	return strings.TrimSpace(val)
}

func isUnlikelyNode(n *html.Node) bool {
	if n.Type != html.ElementNode {
		return false
	}

	classId := strings.ToLower(getAttr(n, "class") + " " + getAttr(n, "id"))
	if classId == " " {
		return false
	}

	if strings.Contains(classId, "article") ||
		strings.Contains(classId, "main") ||
		strings.Contains(classId, "content") {
		return false
	}

	unlikelyKeywords := []string{
		"menu", "nav", "footer", "sidebar", "cookie", "banner", "sponsor",
		"advert", "popup", "modal", "newsletter", "share", "social",
	}
	for _, keyword := range unlikelyKeywords {
		if strings.Contains(classId, keyword) {
			return true
		}
	}
	return false
}

type converter struct {
	stack      []*bytes.Buffer
	linkHrefs  []string
	linkStates []bool
	emphStack  []string // Tracks "**", "*", "~~" for buffered emphasis
	olCounters []int
	inPre      bool
	listDepth  int
}

func newConverter() *converter {
	return &converter{
		stack: []*bytes.Buffer{{}},
	}
}

func (c *converter) write(s string) {
	c.stack[len(c.stack)-1].WriteString(s)
}

func (c *converter) pushBuf() {
	c.stack = append(c.stack, &bytes.Buffer{})
}

func (c *converter) popBuf() string {
	top := c.stack[len(c.stack)-1]
	c.stack = c.stack[:len(c.stack)-1]
	return top.String()
}

func (c *converter) walk(n *html.Node) {
	if n.Type == html.ElementNode {
		if skipTags[n.Data] {
			return
		}
		if isUnlikelyNode(n) {
			return
		}
	}

	if n.Type == html.TextNode {
		text := n.Data
		if !c.inPre {
			text = strings.ReplaceAll(text, "\n", " ")
			text = reSpaces.ReplaceAllString(text, " ")
		}
		if text != "" {
			c.write(text)
		}
		return
	}

	if n.Type != html.ElementNode {
		for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
			c.walk(ch)
		}
		return
	}

	// Opening Tags
	switch n.Data {
	// Buffer emphasis content so we can TrimSpace the inner text,
	// avoiding the regex-across-boundaries bug.
	case "b", "strong":
		c.emphStack = append(c.emphStack, "**")
		c.pushBuf()
	case "i", "em":
		c.emphStack = append(c.emphStack, "*")
		c.pushBuf()
	case "del", "s":
		c.emphStack = append(c.emphStack, "~~")
		c.pushBuf()
	case "a":
		href := normalizeAttr(getAttr(n, "href"))
		if href != "" && !isSafeHref(href) {
			href = "#"
		}
		hasHref := href != ""
		c.linkStates = append(c.linkStates, hasHref)
		if hasHref {
			c.linkHrefs = append(c.linkHrefs, href)
			c.pushBuf()
		}
	case "h1":
		c.write("\n\n# ")
	case "h2":
		c.write("\n\n## ")
	case "h3":
		c.write("\n\n### ")
	case "h4":
		c.write("\n\n#### ")
	case "h5":
		c.write("\n\n##### ")
	case "h6":
		c.write("\n\n###### ")
	case "p":
		c.write("\n\n")
	case "br":
		c.write("\n")
	case "hr":
		c.write("\n\n---\n\n")
	case "ol":
		c.olCounters = append(c.olCounters, 1)
		// Only write leading newline for top-level list.
		if c.listDepth == 0 {
			c.write("\n")
		}
		c.listDepth++
	case "ul":
		if c.listDepth == 0 {
			c.write("\n")
		}
		c.listDepth++
	case "li":
		c.write("\n")
		// Indent nested items by four spaces per level.
		if c.listDepth > 1 {
			c.write(strings.Repeat("    ", c.listDepth-1))
		}
		if n.Parent != nil && n.Parent.Data == "ol" && len(c.olCounters) > 0 {
			idx := c.olCounters[len(c.olCounters)-1]
			c.write(strconv.Itoa(idx) + ". ")
			c.olCounters[len(c.olCounters)-1]++
		} else {
			c.write("- ")
		}
	case "pre":
		c.inPre = true
		c.write("\n\n```\n")
	case "code":
		if !c.inPre {
			c.write("`")
		}
	case "blockquote":
		c.pushBuf()
		for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
			c.walk(ch)
		}
		inner := strings.TrimSpace(c.popBuf())
		lines := strings.Split(inner, "\n")
		var quoted []string
		for _, l := range lines {
			if strings.TrimSpace(l) == "" {
				quoted = append(quoted, ">")
			} else {
				quoted = append(quoted, "> "+l)
			}
		}
		// Collapse consecutive empty quote lines into one.
		var deduped []string
		for i, line := range quoted {
			if line == ">" && i > 0 && deduped[len(deduped)-1] == ">" {
				continue
			}
			deduped = append(deduped, line)
		}
		c.write("\n\n" + strings.Join(deduped, "\n") + "\n\n")
		return
	case "img":
		src := normalizeAttr(getAttr(n, "src"))
		if src == "" {
			src = normalizeAttr(getAttr(n, "data-src"))
		}
		if src == "" {
			return
		}
		alt := escapeMdAlt(normalizeAttr(getAttr(n, "alt")))
		if isSafeImageSrc(src) {
			c.write("![" + alt + "](" + src + ")")
		}
		return
	}

	// Traverse Children
	for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
		c.walk(ch)
	}

	// Closing Tags
	switch n.Data {
	// Pop buffer, trim, wrap with the correct marker.
	case "b", "strong", "i", "em", "del", "s":
		if len(c.emphStack) == 0 {
			break
		}
		marker := c.emphStack[len(c.emphStack)-1]
		c.emphStack = c.emphStack[:len(c.emphStack)-1]
		inner := strings.TrimSpace(c.popBuf())
		if inner != "" {
			c.write(marker + inner + marker)
		}
	case "a":
		if len(c.linkStates) == 0 {
			break
		}
		hasHref := c.linkStates[len(c.linkStates)-1]
		c.linkStates = c.linkStates[:len(c.linkStates)-1]
		if !hasHref {
			break
		}
		href := c.linkHrefs[len(c.linkHrefs)-1]
		c.linkHrefs = c.linkHrefs[:len(c.linkHrefs)-1]
		inner := strings.TrimSpace(c.popBuf())
		if strings.Contains(inner, "\n") {
			// Multi-line link content: link only the first non-image text line.
			lines := strings.Split(inner, "\n")
			linked := false
			for i, l := range lines {
				cleanLine := strings.TrimSpace(l)
				if cleanLine != "" && !strings.HasPrefix(cleanLine, "![") && !linked {
					lines[i] = "[" + cleanLine + "](" + href + ")"
					linked = true
				}
			}
			c.write(strings.Join(lines, "\n"))
		} else {
			c.write("[" + inner + "](" + href + ")")
		}
	case "h1", "h2", "h3", "h4", "h5", "h6", "p", "div", "section", "article",
		"header", "footer", "aside", "nav", "figure":
		c.write("\n")
	case "ol":
		c.listDepth--
		if len(c.olCounters) > 0 {
			c.olCounters = c.olCounters[:len(c.olCounters)-1]
		}
		if c.listDepth == 0 {
			c.write("\n")
		}
	case "ul":
		c.listDepth--
		if c.listDepth == 0 {
			c.write("\n")
		}
	case "pre":
		c.inPre = false
		c.write("\n```\n\n")
	case "code":
		if !c.inPre {
			c.write("`")
		}
	}
}

// HtmlToMarkdown parses htmlStr and renders it as Markdown, skipping
// non-content elements (scripts, navigation, likely ads and similar noise)
// and replacing unsafe link targets with "#".
func HtmlToMarkdown(htmlStr string) (string, error) {
	doc, err := html.Parse(strings.NewReader(htmlStr))
	if err != nil {
		return "", err
	}

	c := newConverter()
	c.walk(doc)
	res := c.stack[0].String()

	// Post-processing
	res = reImageOnlyLink.ReplaceAllString(res, "")
	res = reEmptyListItem.ReplaceAllString(res, "")
	res = reEmptyHeader.ReplaceAllString(res, "")

	lines := strings.Split(res, "\n")
	var cleanLines []string
	for _, line := range lines {
		line = strings.TrimRight(line, " \t")
		cleanTest := strings.TrimSpace(line)
		if cleanTest == "[](>)" || cleanTest == "[](#)" || cleanTest == "-" {
			cleanLines = append(cleanLines, "")
			continue
		}
		cleanLines = append(cleanLines, line)
	}
	res = strings.Join(cleanLines, "\n")
	res = strings.TrimSpace(res)
	res = reNewlines.ReplaceAllString(res, "\n\n")

	// Strip a single leading space from lines that are NOT list indentation.
	// "(?m)^([ \t])([^ \t\n])" matches exactly one space/tab at line start followed
	// by a non-whitespace char, so "    - nested" (4 spaces) is left untouched.
	res = reLeadingLineSpace.ReplaceAllString(res, "$2")

	return res, nil
}

================================================
FILE: pkg/utils/markdown_test.go
================================================
package utils

import (
	"testing"

	"github.com/sipeed/picoclaw/pkg/logger"
)

func TestHtmlToMarkdown(t *testing.T) {
	// Define our test cases
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name: "Removes scripts and styles",
			input: `Clean text
`, expected: "Clean text", }, { name: "Extracts links correctly", input: `Visit my website for info.`, expected: "Visit my [website](https://example.com) for info.", }, { name: "Converts headers (H1, H2, H3)", input: `First paragraph
Second paragraph with
a line break.
`,
// Correct Markdown syntax for images
expected: "",
},
{
name: "Image support without alt-text",
input: `
`,
// If alt is missing, square brackets remain empty
expected: "",
},
{
name: "XSS Bypass on Links (Obfuscated HTML entities)",
// The Go HTML parser resolves entities, so this becomes "javascript:alert(1)"
input: `Click here`,
// Our isSafeHref (if updated with net/url) should neutralize it to "#"
expected: "[Click here](#)",
},
{
name: "Empty link or used as anchor",
input: ``,
// With no text or href, it shouldn't print anything (not even empty brackets)
expected: "",
},
{
name: "Link without href but with text (Textual anchor)",
input: `Back to top`,
// Should extract only plain text, without generating a broken Markdown link like [Back to top](#) or [Back to top]()
expected: "Back to top",
},
{
name: "Badly spaced bold and italics (Edge Case)",
input: ` Text `,
// In Markdown `** Text **` is often not formatted correctly. The ideal is `**Text**`
expected: "**Text**",
},
{
name: "Complex Test - Real Article",
input: `
This is an introductory text with a link.
func main() {\n fmt.Println(\"hello\")\n}",
expected: "```\nfunc main() {\n fmt.Println(\"hello\")\n}\n```",
},
{
name: "Inline code",
input: `Use the command go test ./... to run the tests.
`, expected: "> An important quote.", }, { name: "Multiline blockquote", input: `An important quote.
`, expected: "> First line of the quote.\n>\n> Second line of the quote.", }, { name: "Strikethrough text (del/s)", input: `This text isFirst line of the quote.
Second line of the quote.
Above the line
Below the line
`, expected: "Above the line\n\n---\n\nBelow the line", }, { name: "Bold nested in link", input: `Linked bold text`, expected: "[**Linked bold text**](https://example.com)", }, { name: "data-src Image (lazy loading)", input: `Deeply nested text
Important: read the critical instructions carefully.
`, expected: "**Important:** read the ***critical instructions*** *carefully*.", }, { name: "Article with nav and aside sections (noise to filter)", input: `This is the body of the article.
`,
// The image-link without text must not generate broken markup
expected: "[](https://example.com)",
},
{
name: "Empty content or only whitespace",
input: `
`,
expected: "",
},
{
name: "Image with javascript: src blocked",
input: `