Repository: CNTRLAI/Notate
Branch: main
Commit: bc5c346ab1d2
Files: 298
Total size: 1002.2 KB

Directory structure:
gitextract_jmk8mdb7/
├── .gitignore
├── Backend/
│   ├── .gitignore
│   ├── ensure_dependencies.py
│   ├── main.py
│   ├── requirements.txt
│   ├── src/
│   │   ├── authentication/
│   │   │   ├── api_key_authorization.py
│   │   │   └── token.py
│   │   ├── data/
│   │   │   ├── dataFetch/
│   │   │   │   ├── webcrawler.py
│   │   │   │   └── youtube.py
│   │   │   ├── dataIntake/
│   │   │   │   ├── csvFallbackSplitting.py
│   │   │   │   ├── fileTypes/
│   │   │   │   │   └── loadX.py
│   │   │   │   ├── getHtmlFiles.py
│   │   │   │   ├── loadFile.py
│   │   │   │   └── textSplitting.py
│   │   │   └── database/
│   │   │       ├── checkAPIKey.py
│   │   │       ├── db.py
│   │   │       ├── getCollectionInfo.py
│   │   │       └── getLLMApiKey.py
│   │   ├── endpoint/
│   │   │   ├── api.py
│   │   │   ├── deleteStore.py
│   │   │   ├── devApiCall.py
│   │   │   ├── embed.py
│   │   │   ├── models.py
│   │   │   ├── ragQuery.py
│   │   │   ├── transcribe.py
│   │   │   ├── vectorQuery.py
│   │   │   └── webcrawl.py
│   │   ├── llms/
│   │   │   ├── llmQuery.py
│   │   │   ├── messages/
│   │   │   │   └── formMessages.py
│   │   │   └── providers/
│   │   │       ├── local.py
│   │   │       ├── ollama.py
│   │   │       ├── ooba.py
│   │   │       └── openai.py
│   │   ├── models/
│   │   │   ├── __init__.py
│   │   │   ├── exceptions.py
│   │   │   ├── loaders/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   ├── exllama.py
│   │   │   │   ├── hqq.py
│   │   │   │   ├── llamaccphf.py
│   │   │   │   ├── llamacpp.py
│   │   │   │   ├── tensorrt.py
│   │   │   │   └── transformers.py
│   │   │   ├── manager.py
│   │   │   ├── streamer.py
│   │   │   └── utils/
│   │   │       ├── __init__.py
│   │   │       ├── detect_type.py
│   │   │       ├── device.py
│   │   │       ├── download.py
│   │   │       └── platform.py
│   │   ├── vectorstorage/
│   │   │   ├── embeddings.py
│   │   │   ├── helpers/
│   │   │   │   └── sanitizeCollectionName.py
│   │   │   ├── init_store.py
│   │   │   └── vectorstore.py
│   │   └── voice/
│   │       └── voice_to_text.py
│   └── tests/
│       ├── testApi.py
│       └── test_voice.py
├── Frontend/
│   ├── .gitignore
│   ├── build/
│   │   └── icons/
│   │       └── icon.icns
│   ├── components.json
│   ├── e2e/
│   │   └── app.spec.ts
│   ├── electron-builder.json
│   ├── eslint.config.js
│   ├── index.html
│   ├── package.json
│   ├── playwright.config.ts
│   ├── postcss.config.js
│   ├── src/
│   │   ├── app/
│   │   │   ├── App.tsx
│   │   │   ├── index.css
│   │   │   ├── main.tsx
│   │   │   └── vite-env.d.ts
│   │   ├── components/
│   │   │   ├── AppAlert/
│   │   │   │   └── SettingsAlert.tsx
│   │   │   ├── Authentication/
│   │   │   │   ├── CreateAccount.tsx
│   │   │   │   └── SelectAccount.tsx
│   │   │   ├── Chat/
│   │   │   │   ├── Chat.tsx
│   │   │   │   └── ChatComponents/
│   │   │   │       ├── ChatHeader.tsx
│   │   │   │       ├── ChatInput.tsx
│   │   │   │       ├── ChatMessage.tsx
│   │   │   │       ├── ChatMessagesArea.tsx
│   │   │   │       ├── LoadingIndicator.tsx
│   │   │   │       ├── NewConvoWelcome.tsx
│   │   │   │       ├── ReasoningMessage.tsx
│   │   │   │       ├── StreamingMessage.tsx
│   │   │   │       ├── StreamingReasoningMessage.tsx
│   │   │   │       ├── SyntaxHightlightedCode.tsx
│   │   │   │       └── suggestions.tsx
│   │   │   ├── CollectionModals/
│   │   │   │   ├── CollectionComponents/
│   │   │   │   │   ├── AddLibrary.tsx
│   │   │   │   │   ├── DataStoreSelect.tsx
│   │   │   │   │   ├── FIlesInCollection.tsx
│   │   │   │   │   ├── Ingest.tsx
│   │   │   │   │   ├── IngestProgress.tsx
│   │   │   │   │   ├── IngestTabs/
│   │   │   │   │   │   ├── FileIngestTab.tsx
│   │   │   │   │   │   └── LinkIngestTab.tsx
│   │   │   │   │   └── ingestTypes.tsx
│   │   │   │   └── LibraryModal.tsx
│   │   │   ├── FileExplorer/
│   │   │   │   └── FileExplorer.tsx
│   │   │   ├── Header/
│   │   │   │   ├── Header.tsx
│   │   │   │   └── HeaderComponents/
│   │   │   │       ├── MainWindowControl.tsx
│   │   │   │       ├── Search.tsx
│   │   │   │       ├── SettingsDialog.tsx
│   │   │   │       ├── ToolsDialog.tsx
│   │   │   │       └── WinLinuxControls.tsx
│   │   │   ├── History/
│   │   │   │   └── History.tsx
│   │   │   ├── SettingsModal/
│   │   │   │   ├── SettingsComponents/
│   │   │   │   │   ├── ChatSettings.tsx
│   │   │   │   │   ├── DevIntegration.tsx
│   │   │   │   │   ├── LLMModels/
│   │   │   │   │   │   ├── AddLocalModel.tsx
│   │   │   │   │   │   ├── AddOllamaModel.tsx
│   │   │   │   │   │   ├── AzureOpenAI.tsx
│   │   │   │   │   │   ├── CustomLLM.tsx
│   │   │   │   │   │   ├── External.tsx
│   │   │   │   │   │   ├── ExternalOllama.tsx
│   │   │   │   │   │   ├── LocalLLM.tsx
│   │   │   │   │   │   ├── Ollama.tsx
│   │   │   │   │   │   └── Openrouter.tsx
│   │   │   │   │   ├── LLMPanel.tsx
│   │   │   │   │   └── providers/
│   │   │   │   │       ├── SvgIcon.tsx
│   │   │   │   │       ├── defaultsProviderModels.tsx
│   │   │   │   │       └── providerIcons.tsx
│   │   │   │   └── SettingsModal.tsx
│   │   │   ├── Tools/
│   │   │   │   ├── ToolComponents/
│   │   │   │   │   ├── AddTools.tsx
│   │   │   │   │   └── EnableTools.tsx
│   │   │   │   └── Tools.tsx
│   │   │   └── ui/
│   │   │       ├── alert.tsx
│   │   │       ├── avatar.tsx
│   │   │       ├── badge.tsx
│   │   │       ├── button.tsx
│   │   │       ├── buttonVariants.tsx
│   │   │       ├── card.tsx
│   │   │       ├── command.tsx
│   │   │       ├── dialog.tsx
│   │   │       ├── form.tsx
│   │   │       ├── icons.tsx
│   │   │       ├── input.tsx
│   │   │       ├── label.tsx
│   │   │       ├── menubar.tsx
│   │   │       ├── popover.tsx
│   │   │       ├── progress.tsx
│   │   │       ├── radio-group.tsx
│   │   │       ├── scroll-area.tsx
│   │   │       ├── select.tsx
│   │   │       ├── separator.tsx
│   │   │       ├── sheet.tsx
│   │   │       ├── slider.tsx
│   │   │       ├── switch.tsx
│   │   │       ├── tabs.tsx
│   │   │       ├── textarea.tsx
│   │   │       ├── toast.tsx
│   │   │       ├── toaster.tsx
│   │   │       └── tooltip.tsx
│   │   ├── context/
│   │   │   ├── ChatInputContext.tsx
│   │   │   ├── LibraryContext.tsx
│   │   │   ├── SysSettingsContext.tsx
│   │   │   ├── UserClientProviders.tsx
│   │   │   ├── UserContext.tsx
│   │   │   ├── ViewContext.tsx
│   │   │   ├── useChatInput.tsx
│   │   │   ├── useLibrary.tsx
│   │   │   ├── useSysSettings.tsx
│   │   │   ├── useUser.tsx
│   │   │   └── useView.tsx
│   │   ├── data/
│   │   │   ├── models.ts
│   │   │   └── sysSpecs.ts
│   │   ├── electron/
│   │   │   ├── authentication/
│   │   │   │   ├── devApi.ts
│   │   │   │   ├── secret.ts
│   │   │   │   └── token.ts
│   │   │   ├── crawl/
│   │   │   │   ├── cancelWebcrawl.ts
│   │   │   │   └── webcrawl.ts
│   │   │   ├── db.ts
│   │   │   ├── embedding/
│   │   │   │   ├── cancelEmbed.ts
│   │   │   │   └── vectorstoreQuery.ts
│   │   │   ├── handlers/
│   │   │   │   ├── azureHandlers.ts
│   │   │   │   ├── chatHandlers.ts
│   │   │   │   ├── closeEventHandler.ts
│   │   │   │   ├── collectionHandlers.ts
│   │   │   │   ├── customApiHandlers.ts
│   │   │   │   ├── dbHandlers.ts
│   │   │   │   ├── fileHandlers.ts
│   │   │   │   ├── handlers.test.ts
│   │   │   │   ├── ipcHandlers.ts
│   │   │   │   ├── localModelHandlers.ts
│   │   │   │   ├── menuHandlers.ts
│   │   │   │   ├── ollamaHandlers.ts
│   │   │   │   ├── openRouterHandlers.ts
│   │   │   │   └── voiceHandlers.ts
│   │   │   ├── helpers/
│   │   │   │   └── spawnAsync.ts
│   │   │   ├── llms/
│   │   │   │   ├── agentLayer/
│   │   │   │   │   ├── anthropicAgent.ts
│   │   │   │   │   ├── geminiAgent.ts
│   │   │   │   │   ├── ollamaAgent.ts
│   │   │   │   │   ├── openAiAgent.ts
│   │   │   │   │   └── tools/
│   │   │   │   │       └── websearch.ts
│   │   │   │   ├── apiCheckProviders/
│   │   │   │   │   ├── anthropic.ts
│   │   │   │   │   ├── deepseek.ts
│   │   │   │   │   ├── gemini.ts
│   │   │   │   │   ├── openai.ts
│   │   │   │   │   ├── openrouter.ts
│   │   │   │   │   └── xai.ts
│   │   │   │   ├── chatCompletion.ts
│   │   │   │   ├── generateTitle.ts
│   │   │   │   ├── keyValidation.ts
│   │   │   │   ├── llmHelpers/
│   │   │   │   │   ├── addAssistantMessage.ts
│   │   │   │   │   ├── addUserMessage.ts
│   │   │   │   │   ├── collectionData.ts
│   │   │   │   │   ├── countMessageTokens.ts
│   │   │   │   │   ├── getUserPrompt.ts
│   │   │   │   │   ├── ifNewConvo.ts
│   │   │   │   │   ├── prepMessages.ts
│   │   │   │   │   ├── providerInit.ts
│   │   │   │   │   ├── providersMap.ts
│   │   │   │   │   ├── returnReasoningPrompt.ts
│   │   │   │   │   ├── returnSystemPrompt.ts
│   │   │   │   │   ├── sendMessageChunk.ts
│   │   │   │   │   └── truncateMessages.ts
│   │   │   │   ├── llms.ts
│   │   │   │   ├── providers/
│   │   │   │   │   ├── anthropic.ts
│   │   │   │   │   ├── azureOpenAI.ts
│   │   │   │   │   ├── customEndpoint.ts
│   │   │   │   │   ├── deepseek.ts
│   │   │   │   │   ├── externalOllama.ts
│   │   │   │   │   ├── gemini.ts
│   │   │   │   │   ├── localModel.ts
│   │   │   │   │   ├── ollama.ts
│   │   │   │   │   ├── openai.ts
│   │   │   │   │   ├── openrouter.ts
│   │   │   │   │   └── xai.ts
│   │   │   │   └── reasoningLayer/
│   │   │   │       └── openAiChainOfThought.ts
│   │   │   ├── loadingWindow.ts
│   │   │   ├── localLLMs/
│   │   │   │   ├── getDirModels.ts
│   │   │   │   ├── loadModel.ts
│   │   │   │   ├── modelInfo.ts
│   │   │   │   └── unloadModel.ts
│   │   │   ├── main.ts
│   │   │   ├── mainWindow.test.ts
│   │   │   ├── mainWindow.ts
│   │   │   ├── menu.ts
│   │   │   ├── ollama/
│   │   │   │   ├── checkOllama.ts
│   │   │   │   ├── fetchLocalModels.ts
│   │   │   │   ├── getRunningModels.ts
│   │   │   │   ├── isOllamaRunning.ts
│   │   │   │   ├── ollamaPath.ts
│   │   │   │   ├── pullModel.ts
│   │   │   │   ├── runOllama.ts
│   │   │   │   ├── unloadAllModels.ts
│   │   │   │   └── unloadModel.ts
│   │   │   ├── pathResolver.ts
│   │   │   ├── preload.cts
│   │   │   ├── python/
│   │   │   │   ├── ensurePythonAndVenv.ts
│   │   │   │   ├── extractFromAsar.ts
│   │   │   │   ├── getLinuxPackageManager.ts
│   │   │   │   ├── ifFedora.ts
│   │   │   │   ├── installDependencies.ts
│   │   │   │   ├── installLlamaCpp.ts
│   │   │   │   ├── killProcessOnPort.ts
│   │   │   │   ├── python.test.ts
│   │   │   │   ├── runWithPrivileges.ts
│   │   │   │   └── startAndStopPython.ts
│   │   │   ├── resourceManager.ts
│   │   │   ├── specs/
│   │   │   │   └── systemSpecs.ts
│   │   │   ├── storage/
│   │   │   │   ├── deleteCollection.ts
│   │   │   │   ├── getFiles.ts
│   │   │   │   ├── getUserFiles.ts
│   │   │   │   ├── newFile.ts
│   │   │   │   ├── openCollectionFolder.ts
│   │   │   │   ├── removeFileorFolder.ts
│   │   │   │   ├── renameFile.ts
│   │   │   │   └── websiteFetch.ts
│   │   │   ├── tray.test.ts
│   │   │   ├── tray.ts
│   │   │   ├── tsconfig.json
│   │   │   ├── util.ts
│   │   │   ├── voice/
│   │   │   │   └── audioTranscription.ts
│   │   │   └── youtube/
│   │   │       └── youtubeIngest.ts
│   │   ├── hooks/
│   │   │   ├── use-toast.ts
│   │   │   ├── useAppInitialization.tsx
│   │   │   ├── useChatLogic.ts
│   │   │   ├── useChatManagement.ts
│   │   │   ├── useConversationManagement.ts
│   │   │   ├── useModelManagement.ts
│   │   │   ├── useStatistics.tsx
│   │   │   └── useUIState.ts
│   │   ├── lib/
│   │   │   ├── shikiHightlight.ts
│   │   │   └── utils.ts
│   │   ├── loading.html
│   │   ├── types/
│   │   │   └── contextTypes/
│   │   │       ├── LibraryContextTypes.ts
│   │   │       ├── SystemSettingsTypes.ts
│   │   │       ├── UserContextType.ts
│   │   │       └── UserViewTypes.ts
│   │   └── utils/
│   │       ├── chatUtilts.ts
│   │       └── webAudioRecorder.ts
│   ├── tailwind.config.js
│   ├── tsconfig.app.json
│   ├── tsconfig.json
│   ├── tsconfig.node.json
│   ├── types.d.ts
│   ├── vite.config.d.ts
│   ├── vite.config.js
│   └── vite.config.ts
├── LICENSE
└── README.md

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Python cache files
__pycache__/
*.py[cod]
*$py.class
.venv
/Frontend/node_modules
/Frontend/dist
.env.local
database.sqlite
Backend/venv
Backend/venvs
models/*
.DS_Store
*.tsbuildinfo
Collections/*
FileCollections
.dev.secret
VectorStores/*
Frontend/chroma_db/chroma.sqlite3
monitor_resources.ps1
Frontend/models/*
Backend/models/*
test_curl.txt

================================================
FILE: Backend/.gitignore
================================================
venv
testData
================================================
FILE: Backend/ensure_dependencies.py
================================================
import sys
import os
import subprocess
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
import warnings
import logging

# Filter transformers model warnings
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

# Configure logging to handle progress messages
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def find_python310():
    # Despite the name, this looks for Python 3.12 on macOS/Linux
    # and Python 3.11 on Windows.
    python_commands = ["python3.12", "python3"] if sys.platform != "win32" else [
        "python3.11", "py -3.11", "python"]
    for cmd in python_commands:
        try:
            result = subprocess.run(
                [cmd, "--version"],
                stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            if sys.platform == "win32":
                if "Python 3.11" in result.stdout:
                    return cmd
            else:
                if "Python 3.12" in result.stdout:
                    return cmd
        except:
            continue
    return None


def create_venv(venv_path=None):
    if venv_path is None:
        venv_path = os.path.join(os.path.dirname(__file__), 'venv')
    if not os.path.exists(venv_path):
        print("Creating virtual environment...")
        python310 = find_python310()
        if not python310:
            if sys.platform == "win32":
                raise RuntimeError(
                    "Python 3.11 is required but not found. Please install Python 3.11.")
            else:
                raise RuntimeError(
                    "Python 3.12 is required but not found. Please install Python 3.12.")
        subprocess.check_call([python310, "-m", "venv", venv_path])
        print(f"Created virtual environment with {python310}")
    return venv_path


def get_venv_python(venv_path):
    if sys.platform == "win32":
        return os.path.join(venv_path, "Scripts", "python.exe")
    return os.path.join(venv_path, "bin", "python")


def install_package(python_path, package):
    try:
        subprocess.check_call(
            [python_path, '-m', 'pip', 'install', '--no-deps',
             '--upgrade-strategy', 'only-if-needed', package],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
        return package, None
    except subprocess.CalledProcessError as e:
        return package, str(e)


def get_installed_packages(python_path):
    result = subprocess.run(
        [python_path, '-m', 'pip', 'list', '--format=freeze'],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
    )
    return {line.split('==')[0].lower(): line.split('==')[1]
            for line in result.stdout.splitlines()}


async def async_init_store():
    try:
        # Suppress model initialization warnings
        import transformers
        from src.vectorstorage.init_store import init_store
        transformers.logging.set_verbosity_error()
        logging.getLogger(
            "transformers.modeling_utils").setLevel(logging.ERROR)
        # Configure huggingface_hub logging
        hf_logger = logging.getLogger("huggingface_hub")
        hf_logger.setLevel(logging.INFO)
        sys.stdout.write(
            "Downloading initial embedding model (HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5) ...|85\n")
        sys.stdout.flush()
        # Redirect stderr to capture progress messages
        with open(os.devnull, 'w') as devnull:
            old_stderr = sys.stderr
            sys.stderr = devnull
            try:
                model_path = await init_store()
                sys.stdout.write(
                    f"Model downloaded successfully to {model_path}|95\n")
            finally:
                sys.stderr = old_stderr
        sys.stdout.flush()
    except Exception as e:
        sys.stdout.write(f"Error downloading model: {str(e)}|85\n")
        sys.stdout.flush()
        raise e


def get_package_version(python_path, package_name):
    try:
        result = subprocess.run(
            [python_path, '-m', 'pip', 'show', package_name],
            capture_output=True, text=True
        )
        for line in result.stdout.split('\n'):
            if line.startswith('Version: '):
                version = line.split('Version: ')[1].strip()
                # Handle CUDA variants of PyTorch
                if package_name == 'torch' and '+cu' in version:
                    # Strip CUDA suffix for version comparison
                    version = version.split('+')[0]
                return version
    except:
        return None
    return None


def install_requirements(custom_venv_path=None):
    try:
        venv_path = create_venv(custom_venv_path)
        python_path = get_venv_python(venv_path)
        # Install core dependencies first
        requirements_path = os.path.join(
            os.path.dirname(__file__), 'requirements.txt')
        # Handle remaining requirements
        with open(requirements_path, 'r') as f:
            requirements = [
                line.strip() for line in f
                if line.strip() and not line.startswith('#')
            ]
        total_deps = len(requirements)
        sys.stdout.write(f"Total packages to process: {total_deps}|50\n")
        sys.stdout.flush()
        installed_packages = get_installed_packages(python_path)
        to_install = []
        for req in requirements:
            pkg_name = req.split('==')[0] if '==' in req else req
            if pkg_name.lower() not in installed_packages:
                to_install.append(req)
        completed_deps = total_deps - len(to_install)
        progress = 50 + (completed_deps / total_deps) * \
            30  # Scale from 50 to 80
        sys.stdout.write(f"Checked installed packages|{progress:.1f}\n")
        sys.stdout.flush()
        with ThreadPoolExecutor(max_workers=5) as executor:
            future_to_pkg = {executor.submit(
                install_package, python_path, req): req for req in to_install}
            for future in as_completed(future_to_pkg):
                pkg = future_to_pkg[future]
                pkg_name = pkg.split('==')[0] if '==' in pkg else pkg
                result, error = future.result()
                completed_deps += 1
                progress = 50 + (completed_deps / total_deps) * \
                    30  # Scale from 50 to 80
                if error:
                    sys.stdout.write(
                        f"Error installing {pkg_name}: {error}|{progress:.1f}\n")
                else:
                    sys.stdout.write(f"Installed {pkg_name}|{progress:.1f}\n")
                sys.stdout.flush()
        # Now we can safely import init_store after all dependencies are installed
        sys.stdout.write(
            "All dependencies installed, initializing model store...|85\n")
        sys.stdout.flush()
        # Initialize the store to download the model
        asyncio.run(async_init_store())
        sys.stdout.write(
            "Dependencies installed and model initialized successfully!|99\n")
        sys.stdout.flush()
    except Exception as e:
        sys.stdout.write(f"Error installing dependencies: {str(e)}|0\n")
        sys.stdout.flush()
        sys.exit(1)


if __name__ == "__main__":
    custom_venv_path = sys.argv[1] if len(sys.argv) > 1 else None
    install_requirements(custom_venv_path)
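
------------------------------------------------
EXAMPLE: consuming the "message|percent" progress protocol
------------------------------------------------
ensure_dependencies.py reports progress to its parent process as plain
"message|percent" lines on stdout. A minimal consumer sketch (the launch
command is illustrative; in the real app the Electron frontend plays this role):

import subprocess

proc = subprocess.Popen(
    ["python", "ensure_dependencies.py"],
    stdout=subprocess.PIPE, text=True)
for line in proc.stdout:
    # rpartition tolerates "|" appearing inside the message itself
    message, _, percent = line.rstrip("\n").rpartition("|")
    if message:
        print(f"[{percent}%] {message}")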
================================================
FILE: Backend/main.py
================================================
import logging
from src.authentication.api_key_authorization import api_key_auth
from src.authentication.token import verify_token, verify_token_or_api_key
from src.data.database.checkAPIKey import check_api_key
from src.data.dataFetch.youtube import youtube_transcript
from src.endpoint.deleteStore import delete_vectorstore_collection
from src.endpoint.models import EmbeddingRequest, QueryRequest, ChatCompletionRequest, VectorStoreQueryRequest, DeleteCollectionRequest, YoutubeTranscriptRequest, WebCrawlRequest, ModelLoadRequest
from src.endpoint.embed import embed
from src.endpoint.vectorQuery import query_vectorstore
from src.endpoint.devApiCall import rag_call, llm_call, vector_call
from src.endpoint.transcribe import transcribe_audio
from src.endpoint.webcrawl import webcrawl
from src.models.manager import model_manager
from fastapi import FastAPI, Depends, File, UploadFile, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import asyncio
import os
import signal
import sys
import psutil
import threading
import uvicorn
import json
from src.endpoint.api import chat_completion_stream

app = FastAPI()
embedding_task = None
embedding_event = None
crawl_task = None
crawl_event = None
origins = ["http://localhost", "http://127.0.0.1"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    max_age=3600,  # Cache preflight requests for 1 hour
    expose_headers=["*"]
)


# Configure FastAPI app settings for long-running requests
@app.middleware("http")
async def timeout_middleware(request: Request, call_next):
    try:
        # Set a long timeout for the request (1 hour)
        response = await asyncio.wait_for(call_next(request), timeout=3600)
        return response
    except asyncio.TimeoutError:
        return JSONResponse(
            status_code=504,
            content={"detail": "Request timeout"}
        )


logger = logging.getLogger(__name__)


@app.post("/chat/completions")
async def chat_completion(request: ChatCompletionRequest, user_id: str = Depends(verify_token_or_api_key)) -> StreamingResponse:
    """Stream chat completion from the model"""
    print("Chat completion request received")
    print(user_id, request)
    info = model_manager.get_model_info()
    print(info)
    if request.model != info["model_name"]:
        model_load_request = ModelLoadRequest(model_name=request.model)
        model, tokenizer = model_manager.load_model(model_load_request)
        print("Model mismatch")
        return {"status": "error", "message": "Model mismatch"}
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    print("Authorized")
    print(request)
    return StreamingResponse(
        chat_completion_stream(request),
        media_type="text/event-stream"
    )


@app.get("/model-info")
async def get_model_info(user_id: str = Depends(verify_token_or_api_key)):
    """Get information about the currently loaded model"""
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    return JSONResponse(content=model_manager.get_model_info())


@app.post("/load-model")
async def load_model_endpoint(request: ModelLoadRequest, user_id: str = Depends(verify_token_or_api_key)):
    """Load a model with the specified configuration"""
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    print("Loading model")
    print(request)
    model_type = request.model_type or "auto"
    if model_type != "auto":
        is_compatible, message = model_manager.check_platform_compatibility(
            model_type)
        logger.info(f"is_compatible: {is_compatible}, message: {message}")
        # Return early if platform is not compatible
        if not is_compatible:
            response_data = model_manager._make_json_serializable({
                "status": "error",
                "message": f"Cannot load model: {message}",
                "model_info": model_manager.get_model_info()
            })
            return JSONResponse(content=response_data)
    try:
        model, tokenizer = model_manager.load_model(request)
        response_data = model_manager._make_json_serializable({
            "status": "success",
            "message": f"Successfully loaded model {request.model_name}",
            "model_info": model_manager.get_model_info()
        })
        print(response_data)
        logger.info(response_data)
        return JSONResponse(content=response_data)
    except Exception as e:
        response_data = model_manager._make_json_serializable({
            "status": "error",
            "message": str(e),
            "model_info": model_manager.get_model_info()
        })
        return JSONResponse(status_code=500, content=response_data)


@app.post("/unload-model")
async def unload_model_endpoint(user_id: str = Depends(verify_token_or_api_key)):
    """Unload the currently loaded model"""
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    try:
        model_manager.clear_model()
        return JSONResponse(content={
            "status": "success",
            "message": "Model unloaded successfully",
            "model_info": model_manager.get_model_info()
        })
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": str(e),
                "model_info": model_manager.get_model_info()
            }
        )


@app.post("/webcrawl")
async def webcrawl_endpoint(data: WebCrawlRequest, user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    global crawl_task, crawl_event
    if crawl_task is not None:
        return {"status": "error", "message": "A crawl process is already running"}
    crawl_event = asyncio.Event()

    async def event_generator():
        global crawl_task, crawl_event
        try:
            for result in webcrawl(data, crawl_event):
                if crawl_event.is_set():
                    yield f"data: {{'type': 'cancelled', 'message': 'Crawl process cancelled'}}\n\n"
                    break
                yield f"{result}\n\n"
                await asyncio.sleep(0.1)
        except Exception as e:
            error_data = {"status": "error", "data": {"message": str(e)}}
            yield f"data: {json.dumps(error_data)}\n\n"
        finally:
            crawl_task = None
            crawl_event = None

    response = StreamingResponse(
        event_generator(), media_type="text/event-stream")
    crawl_task = asyncio.create_task(event_generator().__anext__())
    return response


@app.post("/transcribe")
async def transcribe_audio_endpoint(audio_file: UploadFile = File(...), model_name: str = "base", user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    return await transcribe_audio(audio_file, model_name)


@app.post("/embed")
async def add_embedding(data: EmbeddingRequest, user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    print("Metadata:", data.metadata)
    global embedding_task, embedding_event
    if embedding_task is not None:
        return {"status": "error", "message": "An embedding process is already running"}
    embedding_event = asyncio.Event()

    async def event_generator():
        global embedding_task, embedding_event
        try:
            async for result in embed(data):
                if embedding_event.is_set():
                    yield f"data: {{'type': 'cancelled', 'message': 'Embedding process cancelled'}}\n\n"
                    break
                if result["status"] == "progress":
                    progress_data = result["data"]
                    yield f"data: {{'type': 'progress', 'chunk': {progress_data['chunk']}, 'totalChunks': {progress_data['total_chunks']}, 'percent_complete': '{progress_data['percent_complete']}', 'est_remaining_time': '{progress_data['est_remaining_time']}'}}\n\n"
                else:
                    yield f"data: {{'type': '{result['status']}', 'message': '{result['message']}'}}\n\n"
                await asyncio.sleep(0.1)  # Prevent overwhelming the connection
        except Exception as e:
            logger.error(f"Error in embedding process: {str(e)}")
            yield f"data: {{'type': 'error', 'message': '{str(e)}'}}\n\n"
        finally:
            embedding_task = None
            embedding_event = None
            logger.info("Embedding task cleanup completed")

    response = StreamingResponse(
        event_generator(),
        media_type="text/event-stream"
    )
    # Set response headers for better connection handling
    response.headers["Cache-Control"] = "no-cache"
    response.headers["Connection"] = "keep-alive"
    response.headers["X-Accel-Buffering"] = "no"
    response.headers["Transfer-Encoding"] = "chunked"
    embedding_task = asyncio.create_task(event_generator().__anext__())
    return response


@app.post("/youtube-ingest")
async def youtube_ingest(data: YoutubeTranscriptRequest, user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}

    async def event_generator():
        try:
            for result in youtube_transcript(data):
                if result["status"] == "progress":
                    progress_data = result["data"]
                    yield f"data: {{'type': 'progress', 'chunk': {progress_data['chunk']}, 'totalChunks': {progress_data['total_chunks']}, 'percent_complete': '{progress_data['percent_complete']}', 'message': '{progress_data['message']}'}}\n\n"
                else:
                    yield f"data: {{'type': '{result['status']}', 'message': '{result['message']}'}}\n\n"
                await asyncio.sleep(0.1)
        except Exception as e:
            yield f"data: {{'type': 'error', 'message': '{str(e)}'}}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.post("/cancel-embed")
async def cancel_embedding(user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    global embedding_task, embedding_event
    if embedding_event:
        embedding_event.set()
        return {"status": "success", "message": "Embedding process cancelled"}
    return {"status": "error", "message": "No embedding process running"}


@app.post("/restart-server")
async def restart_server(user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}

    def restart():
        pid = os.getpid()
        parent = psutil.Process(pid)
        # Kill all child processes
        for child in parent.children(recursive=True):
            child.kill()
        # Kill the current process
        os.kill(pid, signal.SIGTERM)
        # Start a new instance
        python = sys.executable
        os.execl(python, python, *sys.argv)

    threading.Thread(target=restart).start()
    return {"status": "success", "message": "Server restart initiated"}


@app.post("/vector-query")
async def vector_query(data: VectorStoreQueryRequest, user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    try:
        result = query_vectorstore(data, data.is_local)
        return result
    except Exception as e:
        print(f"Error querying vectorstore: {str(e)}")
        return {"status": "error", "message": str(e)}


@app.post("/delete-collection")
async def delete_collection(data: DeleteCollectionRequest, user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    print("Authorized")
    return delete_vectorstore_collection(data)


@app.post("/api/vector")
async def api_vector(query_request: QueryRequest, user_id: str = Depends(api_key_auth)):
    """Check that the userId has an API key in SQLite"""
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    if not query_request.collection_name:
        print("No collection name provided")
        return {"status": "error", "message": "No collection name provided"}
    if check_api_key(int(user_id)) == False:
        print("Unauthorized")
        return {"status": "error", "message": "Unauthorized"}
    print("Authorized")
    return vector_call(query_request, user_id)


@app.post("/api/llm")
async def api_llm(query_request: ChatCompletionRequest, user_id: str = Depends(api_key_auth)):
    """Check that the userId has an API key in SQLite"""
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    if not query_request.model:
        print("No model provided")
        return {"status": "error", "message": "No model provided"}
    if check_api_key(int(user_id)) == False:
        print("Unauthorized")
        return {"status": "error", "message": "Unauthorized"}
    print("Authorized")
    return await llm_call(query_request, user_id)


@app.post("/api/rag")
async def api_rag(query_request: QueryRequest, user_id: str = Depends(api_key_auth)):
    """Check that the userId has an API key in SQLite"""
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    if not query_request.model:
        print("No model provided")
        return {"status": "error", "message": "No model provided"}
    if not query_request.collection_name:
        print("No collection name provided")
        return {"status": "error", "message": "No collection name provided"}
    if check_api_key(int(user_id)) == False:
        print("Unauthorized")
        return {"status": "error", "message": "Unauthorized"}
    print("Authorized")
    return await rag_call(query_request, user_id)


@app.post("/cancel-crawl")
async def cancel_crawl(user_id: str = Depends(verify_token)):
    if user_id is None:
        return {"status": "error", "message": "Unauthorized"}
    global crawl_task, crawl_event
    if crawl_event:
        crawl_event.set()
        return {"status": "success", "message": "Crawl process cancelled"}
    return {"status": "error", "message": "No crawl process running"}


if __name__ == "__main__":
    print("Starting server...")
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=47372,
        timeout_keep_alive=3600,
        timeout_graceful_shutdown=300,
        limit_concurrency=10,
        backlog=2048
    )
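
------------------------------------------------
EXAMPLE: calling the local API
------------------------------------------------
A client sketch for the server above, assuming a JWT signed with the server's
JWT_SECRET. The exact ChatCompletionRequest fields live in
src/endpoint/models.py (not shown here), so the payload shape below is an
assumption:

import requests

TOKEN = "<jwt signed with JWT_SECRET>"  # placeholder, see token.py
headers = {"Authorization": f"Bearer {TOKEN}"}
base = "http://127.0.0.1:47372"  # port from uvicorn.run above

# Inspect the currently loaded model
info = requests.get(f"{base}/model-info", headers=headers).json()
print(info)

# Stream a chat completion; the server answers with text/event-stream
payload = {"model": info.get("model_name"),  # assumed field names
           "messages": [{"role": "user", "content": "Hello"}]}
with requests.post(f"{base}/chat/completions", headers=headers,
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())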
================================================
FILE: Backend/requirements.txt
================================================
annotated-types==0.7.0
anyio==4.6.2.post1
asgiref==3.8.1
backoff==2.2.1
bcrypt==4.2.1
build==1.2.2.post1
cachetools==5.5.0
certifi==2024.8.30
charset-normalizer==3.4.0
chromadb==0.6.3
chroma-hnswlib==0.7.6
click==8.1.7
coloredlogs==15.0.1
Deprecated==1.2.15
dnspython==2.7.0
durationpy==0.9
ecdsa==0.19.0
email_validator==2.2.0
exceptiongroup==1.2.2
fastapi==0.115.6
fastapi-cli==0.0.6
filelock==3.16.1
flatbuffers==24.3.25
fsspec==2024.10.0
google-auth==2.36.0
googleapis-common-protos==1.66.0
grpcio==1.68.1
h11==0.14.0
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.0
huggingface-hub==0.26.5
humanfriendly==10.0
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.4.5
iniconfig==2.0.0
Jinja2==3.1.5
kubernetes==31.0.0
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
mmh3==5.0.1
monotonic==1.6
mpmath==1.3.0
numba==0.58.1
oauthlib==3.2.2
onnxruntime==1.20.1
opentelemetry-api==1.28.2
opentelemetry-exporter-otlp-proto-common==1.28.2
opentelemetry-exporter-otlp-proto-grpc==1.28.2
opentelemetry-instrumentation==0.49b2
opentelemetry-instrumentation-asgi==0.49b2
opentelemetry-instrumentation-fastapi==0.49b2
opentelemetry-proto==1.28.2
opentelemetry-sdk==1.28.2
opentelemetry-semantic-conventions==0.49b2
opentelemetry-util-http==0.49b2
orjson==3.10.12
overrides==7.7.0
packaging==24.2
passlib==1.7.4
pluggy==1.5.0
posthog==3.7.4
protobuf==5.29.1
pyasn1==0.6.1
pyasn1_modules==0.4.1
pydantic>=2.9.0,<3.0.0
pydantic_core==2.14.6
Pygments==2.18.0
PyPika==0.48.9
pyproject_hooks==1.2.0
pytest==8.3.4
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
PyJWT==2.10.1
python-multipart==0.0.19
PyYAML==6.0.2
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.9.4
rich-toolkit==0.11.3
rsa==4.9
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
starlette==0.41.3
sympy==1.13.3
tenacity==9.0.0
tokenizers==0.21.0
tomli==2.2.1
tqdm==4.67.1
typer==0.15.1
urllib3==2.2.3
uvicorn==0.32.1
watchfiles==1.0.0
websocket-client==1.8.0
websockets==14.1
wrapt==1.17.0
zipp==3.21.0
pypdf[full]==5.1.0
python-docx==0.8.11
beautifulsoup4==4.12.2
markdown==3.5.1
python-pptx==0.6.21
openpyxl==3.1.2
lxml==5.3.0
pandas==2.2.3
pytz==2024.2
pillow==11.0.0
soupsieve==2.6
openai==1.58.1
distro==1.9.0
nest_asyncio==1.5.6
hypercorn==0.14.3
toml==0.10.2
h2==4.1.0
hyperframe==6.0.1
hpack==4.0.0
http3==0.6.7
h11==0.14.0
httpcore==1.0.7
sentence-transformers==3.3.1
threadpoolctl==3.5.0
joblib==1.4.2
scipy==1.15.1
httpx==0.28.0
priority==2.0.0
wsproto==1.2.0
jiter==0.8.2
langchain==0.3.16
langchain-text-splitters==0.3.4
langchain_core==0.3.28
langsmith==0.2.3
requests_toolbelt==1.0.0
jsonpatch==1.33
jsonpointer==3.0.0
langchain_community==0.3.16
tiktoken==0.8.0
regex==2024.11.6
langchain-openai==0.2.14
langchain-chroma==0.2.1
psutil==6.1.1
ollama==0.4.4
docx2txt==0.8
yt-dlp==2024.12.23
webvtt-py==0.4.6
langchain-ollama==0.2.2
openai-whisper==20240930
accelerate>=0.20.3
bitsandbytes>=0.41.1
safetensors>=0.4.0
llvmlite==0.43.0
einops==0.8.0
optimum==1.23.3
datasets==3.2.0
pyarrow==18.1.0
multiprocess==0.70.17
dill>=0.3.6
aiohttp==3.11.11
multidict==6.1.0
attrs>=23.1.0
yarl==1.18.3
propcache==0.2.1
async-timeout==5.0.1
aiohappyeyeballs==2.4.4
aiosignal==1.3.2
frozenlist==1.5.0
xxhash==3.5.0
diskcache==5.6.3
hqq==0.2.2
termcolor==2.5.0
langchain-huggingface==0.1.2
pypdf==5.2.0

================================================
FILE: Backend/src/authentication/api_key_authorization.py
================================================
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from typing import Optional
import jwt
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
    raise RuntimeError("Could not get JWT secret for API key authorization")


async def get_optional_token(token: Optional[str] = Depends(oauth2_scheme)):
    return token


async def api_key_auth(token: Optional[str] = Depends(get_optional_token)):
    if token is None:
        return None
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
        user_id: str = payload.get("userId")
        logger.info(f"User ID: {user_id}")
        if user_id is None:
            return None
        return user_id
    except jwt.exceptions.InvalidTokenError:
        logger.error("Invalid token")
        return None

================================================
FILE: Backend/src/authentication/token.py
================================================
from fastapi import Depends, Request
from fastapi.security import OAuth2PasswordBearer
from typing import Optional
import os
import jwt
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Get secret from environment variable
SECRET_KEY = os.environ.get("JWT_SECRET")
if not SECRET_KEY:
    raise RuntimeError("JWT_SECRET environment variable is not set")


async def get_optional_token(token: Optional[str] = Depends(oauth2_scheme)):
    return token


async def verify_token(token: Optional[str] = Depends(get_optional_token)):
    if token is None:
        return None
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
        print(f"Payload: {payload}")
        user_id: str = payload.get("userId")
        logger.info(f"User ID: {user_id}")
        if user_id is None:
            return None
        return user_id
    except jwt.exceptions.InvalidTokenError:
        logger.error("Invalid token")
        return None


async def optional_auth(request: Request):
    if "Authorization" in request.headers:
        token = request.headers["Authorization"].split("Bearer ")[1]
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
            return payload.get("userId")
        except jwt.exceptions.InvalidTokenError:
            return None
    return None


async def verify_token_or_api_key(token: Optional[str] = Depends(get_optional_token)):
    """Verify token using normal auth, falling back to API key auth if that fails"""
    # Try normal token verification first
    user_id = await verify_token(token)
    if user_id:
        return user_id
    # Fall back to API key verification
    from src.authentication.api_key_authorization import api_key_auth
    return await api_key_auth(token)
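
------------------------------------------------
EXAMPLE: minting a token verify_token will accept
------------------------------------------------
Both auth modules decode HS256 JWTs and read a userId claim; there is no
expiry check. A minimal sketch using PyJWT (pinned in requirements.txt),
with an illustrative secret that must match the server's JWT_SECRET:

import jwt  # PyJWT

SECRET = "my-jwt-secret"  # placeholder for the real JWT_SECRET env var
token = jwt.encode({"userId": "1"}, SECRET, algorithm="HS256")
assert jwt.decode(token, SECRET, algorithms=["HS256"])["userId"] == "1"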
================================================
FILE: Backend/src/data/dataFetch/webcrawler.py
================================================
import os
import json
import logging
import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
import time
import threading
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from queue import Queue, Empty


class WebCrawler:
    def __init__(self, base_url, user_id, user_name, collection_id, collection_name, max_workers, cancel_event=None):
        self.base_url = base_url
        self.output_dir = self._get_collection_path(
            user_id, user_name, collection_id, collection_name)
        self.visited_urls = set()
        self.failed_urls = set()
        self.delay = 0  # Reduced delay since we're rate limiting with max_workers
        self.max_workers = 35  # Note: the max_workers argument is currently ignored
        self.url_queue = Queue()
        self.url_lock = threading.Lock()
        self.progress_bar = None
        self.total_urls = 0
        self.current_urls = 0
        self.update_callback = None
        self.cancel_event = cancel_event
        # Setup logging
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)

    def _get_collection_path(self, user_id, user_name, collection_id, collection_name):
        """Generate the collection path matching the frontend structure"""
        app_data_path = os.path.abspath(os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".."
        ))
        return os.path.join(
            app_data_path, "..", "FileCollections",
            f"{user_id}_{user_name}",
            f"{collection_id}_{collection_name}"
        )

    def _print_progress(self):
        """Print progress as JSON"""
        if self.total_urls > 0:
            percent = (self.current_urls / self.total_urls) * 100
            progress_data = {
                "status": "progress",
                "data": {
                    "message": f"Part 1 of 2: Scraping page {self.current_urls} out of {self.total_urls} from {self.base_url}",
                    "chunk": self.current_urls,
                    "total_chunks": self.total_urls,
                    "percent_complete": f"{percent:.1f}%"
                }
            }
            json_str = json.dumps(progress_data)
            print(f"data: {json_str}")
            return progress_data

    def is_valid_url(self, url):
        """Check if URL belongs to the same domain and is a documentation page"""
        # Remove fragment identifier (#) and anything that follows
        url = url.split('#')[0]
        if not url:  # Skip empty URLs after fragment removal
            return False
        # First check if URL starts with base_url
        if not url.startswith(self.base_url):
            logging.debug(f"Filtered URL (not starting with base URL): {url}")
            return False
        # Remove trailing slashes for consistency
        url = url.rstrip('/')
        # Skip obviously invalid URLs
        invalid_patterns = [
            '.pdf', '.zip', '.png', '.jpg',  # File extensions
            'github.com', 'twitter.com',  # External sites
            '/api/', '/examples/',  # Common non-doc paths
            '?', 'mailto:', 'javascript:'  # Special URLs
        ]
        if any(pattern in url for pattern in invalid_patterns):
            logging.debug(f"Filtered URL (invalid pattern): {url}")
            return False
        # Ensure not a resource file
        return not url.endswith(('js', 'css', 'json'))

    def save_page(self, url, html_content):
        """Save the HTML content to a file"""
        try:
            # Create base_url_docs directory
            parsed_base_url = urlparse(self.base_url)
            base_url_dir = parsed_base_url.netloc.replace(".", "_") + "_docs"
            base_dir = os.path.join(self.output_dir, base_url_dir)
            os.makedirs(base_dir, exist_ok=True)
            # Create a file path based on the URL structure
            parsed_url = urlparse(url)
            path_parts = parsed_url.path.strip('/').split('/')
            # Create subdirectories if needed
            current_dir = base_dir
            for part in path_parts[:-1]:
                current_dir = os.path.join(current_dir, part)
                os.makedirs(current_dir, exist_ok=True)
            # Save the file
            filename = path_parts[-1] if path_parts else 'index'
            filepath = os.path.join(current_dir, f"{filename}.html")
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(html_content)
            return True
        except Exception as e:
            logging.error(f"Error saving {url}: {str(e)}")
            return False

    def get_links(self, soup, current_url):
        """Extract valid documentation links from the page"""
        links = set()
        for a in soup.find_all('a', href=True):
            # Get the full URL
            url = urljoin(current_url, a['href'])
            # Remove fragment identifier (#) and anything that follows
            url = url.split('#')[0]
            # Skip empty URLs after fragment removal
            if not url:
                continue
            # Remove trailing slashes for consistency
            url = url.rstrip('/')
            # Only add if it's valid and not already visited
            if self.is_valid_url(url) and url not in self.visited_urls:
                links.add(url)
        return links

    def scrape_page(self, url):
        """Scrape a single page and return its content and links"""
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            html_content = response.text
            # Create BeautifulSoup object with the response text
            soup = BeautifulSoup(html_content, 'html.parser')
            # Remove unwanted elements before getting links
            for element in soup.find_all(['header', 'footer', 'nav', 'script', 'style', 'meta']):
                if element is not None:
                    element.decompose()
            # Get links from the cleaned soup
            links = self.get_links(soup, url)
            return soup, links
        except Exception as e:
            error_data = {"status": "error", "data": {"message": str(e)}}
            print(f"data: {json.dumps(error_data)}")
            logging.error(f"Error scraping {url}: {str(e)}")
            self.failed_urls.add(url)
            return None, set()

    def scrape(self):
        """Main scraping method using thread pool"""
        # Initialize with start URL
        self.url_queue.put(self.base_url)
        self.total_urls = 1  # Initialize with 1 for the base URL
        self.current_urls = 0
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            active_tasks = set()
            while True:
                try:
                    # Check for cancellation
                    if self.cancel_event and self.cancel_event.is_set():
                        break
                    # Get next URL with timeout
                    try:
                        current_url = self.url_queue.get(timeout=5)
                    except Empty:
                        # If no active tasks and queue is empty, we're done
                        if not active_tasks:
                            break
                        continue
                    if current_url in self.visited_urls:
                        continue
                    with self.url_lock:
                        if current_url in self.visited_urls:
                            continue
                        self.visited_urls.add(current_url)
                    yield self._print_progress()
                    # Submit the scraping task to thread pool
                    future = executor.submit(self._process_url, current_url)
                    active_tasks.add(future)
                    future.add_done_callback(lambda f: active_tasks.remove(f))
                    future.add_done_callback(self._update_progress)
                except Exception as e:
                    error_data = {"status": "error", "data": {"message": str(e)}}
                    print(f"data: {json.dumps(error_data)}")
                    logging.error(f"Error in scrape loop: {str(e)}")
                    continue
            # Wait for remaining tasks to complete
            for future in concurrent.futures.as_completed(list(active_tasks)):
                try:
                    future.result()
                except Exception as e:
                    error_data = {"status": "error", "data": {"message": str(e)}}
                    print(f"data: {json.dumps(error_data)}")
                    logging.error(f"Error in remaining tasks: {str(e)}")

    def _update_progress(self, future):
        """Callback to update progress"""
        try:
            with self.url_lock:
                self.current_urls += 1
                progress_data = self._print_progress()
                if progress_data:
                    json_str = json.dumps(progress_data)
                    print(f"data: {json_str}")
        except Exception as e:
            error_data = {"status": "error", "data": {"message": str(e)}}
            print(f"data: {json.dumps(error_data)}")

    def _process_url(self, url):
        """Process a single URL - called by thread pool"""
        try:
            # Check for cancellation
            if self.cancel_event and self.cancel_event.is_set():
                return
            # Respectful delay
            time.sleep(self.delay)
            # Scrape the page
            soup, new_links = self.scrape_page(url)
            if soup is None:
                return
            # Save the page
            if self.save_page(url, str(soup)):
                # Add new links to queue
                with self.url_lock:
                    for link in new_links:
                        if link not in self.visited_urls and link not in self.url_queue.queue:
                            self.url_queue.put(link)
                            self.total_urls += 1
        except Exception as e:
            error_data = {"status": "error", "data": {"message": str(e)}}
            print(f"data: {json.dumps(error_data)}")
            logging.error(f"Error processing URL {url}: {str(e)}")

    def save_progress(self):
        """Save progress information"""
        with open('scraping_progress.txt', 'w') as f:
            f.write(f"Visited URLs: {len(self.visited_urls)}\n")
            f.write(f"Failed URLs: {len(self.failed_urls)}\n")
            f.write("\nFailed URLs:\n")
            for url in self.failed_urls:
                f.write(f"{url}\n")
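
------------------------------------------------
EXAMPLE: driving WebCrawler.scrape()
------------------------------------------------
scrape() is a generator that yields the same progress dicts it prints; a
usage sketch with illustrative arguments:

import threading

cancel = threading.Event()
crawler = WebCrawler(
    base_url="https://docs.example.com",
    user_id=1, user_name="alice",
    collection_id=2, collection_name="docs",
    max_workers=8,  # note: __init__ currently hardcodes 35 regardless
    cancel_event=cancel,
)
for progress in crawler.scrape():
    if progress:  # _print_progress returns None while total_urls == 0
        print(progress["data"]["percent_complete"])
crawler.save_progress()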
================================================
FILE: Backend/src/data/dataFetch/youtube.py
================================================
import os
from src.endpoint.models import YoutubeTranscriptRequest
from src.vectorstorage.vectorstore import get_vectorstore
from src.vectorstorage.helpers.sanitizeCollectionName import sanitize_collection_name
from langchain_core.documents import Document
import yt_dlp
import logging
import requests
import webvtt
from io import StringIO
from typing import Generator
import json

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def _get_collection_path(user_id, user_name, collection_id, collection_name):
    """Generate the collection path matching the frontend structure"""
    app_data_path = os.path.abspath(os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".."
    ))
    return os.path.join(
        app_data_path, "..", "FileCollections",
        f"{user_id}_{user_name}",
        f"{collection_id}_{collection_name}"
    )


def youtube_transcript(request: YoutubeTranscriptRequest) -> Generator[dict, None, None]:
    """Fetch video transcript and metadata using yt-dlp"""
    logger.info(f"Starting transcript fetch for URL: {request.url}")
    yield {"status": "progress", "data": {"message": f"Starting transcript fetch for URL: {request.url}", "chunk": 1, "total_chunks": 4, "percent_complete": "0%"}}
    ydl_opts = {
        'writesubtitles': True,
        'writeautomaticsub': True,
        'subtitlesformat': 'vtt',
        'skip_download': True,
        'quiet': True,  # Suppress yt-dlp's own output
        'no_warnings': True  # Suppress warnings
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            # Video info extraction (0-5%)
            yield {"status": "progress", "data": {"message": "Extracting video information...", "chunk": 1, "total_chunks": 4, "percent_complete": "5%"}}
            info = ydl.extract_info(request.url, download=False)
            video_info = f"Found video: '{info.get('title', 'Unknown')}' by {info.get('uploader', 'Unknown')}, duration: {info.get('duration', 'Unknown')} seconds"
            logger.info(video_info)
            yield {"status": "progress", "data": {"message": video_info, "chunk": 1, "total_chunks": 4, "percent_complete": "10%"}}
            # Get automatic captions if available
            subtitles = None
            if 'automatic_captions' in info and 'en' in info['automatic_captions']:
                logger.info("Using automatic captions")
                yield {"status": "progress", "data": {"message": "Found automatic captions, processing...", "chunk": 0, "total_chunks": 0, "percent_complete": "0%"}}
                subtitles = info['automatic_captions']['en']
            # Fall back to manual subtitles if available
            elif 'subtitles' in info and 'en' in info['subtitles']:
                logger.info("Using manual subtitles")
                yield {"status": "progress", "data": {"message": "Found manual subtitles, processing...", "chunk": 0, "total_chunks": 0, "percent_complete": "0%"}}
                subtitles = info['subtitles']['en']
            if not subtitles:
                error_msg = "No English subtitles or automatic captions available"
                logger.error(error_msg)
                raise Exception(error_msg)
            # Download the VTT format subtitles
            subtitle_url = None
            for fmt in subtitles:
                if fmt.get('ext') == 'vtt':
                    subtitle_url = fmt['url']
                    break
            if not subtitle_url:
                error_msg = "No VTT format subtitles found"
                logger.error(error_msg)
                raise Exception(error_msg)
            # Update progress for subtitle download (10-15%)
            yield {"status": "progress", "data": {"message": "Downloading subtitles...", "chunk": 2, "total_chunks": 4, "percent_complete": "15%"}}
            # Download the VTT content
            response = requests.get(subtitle_url)
            if response.status_code != 200:
                error_msg = "Failed to download subtitles"
                logger.error(error_msg)
                raise Exception(error_msg)
            # Parse the VTT content
            vtt_content = response.text
            vtt_file = StringIO(vtt_content)
            vtt_captions = webvtt.read_buffer(vtt_file)
            # Start of transcript processing (15-35%)
            yield {"status": "progress", "data": {"message": "Processing subtitles...", "chunk": 2, "total_chunks": 4, "percent_complete": "15%"}}

            def clean_caption(text):
                # Remove common VTT artifacts and clean text
                text = ' '.join(text.split())  # Remove extra whitespace
                # Remove text within brackets (often contains sound effects or speaker labels)
                if text.startswith('[') and text.endswith(']'):
                    return ""
                # Remove common YouTube caption artifacts
                text = text.replace('>>>', '').replace('>>', '')
                # Remove any remaining brackets and their contents
                while '[' in text and ']' in text:
                    start = text.find('[')
                    end = text.find(']') + 1
                    text = text[:start] + text[end:]
                return text.strip()

            def is_substantial_difference(text1, text2):
                # More aggressive deduplication
                if not text1 or not text2:
                    return True
                # Convert to lowercase and split into words
                words1 = text1.lower().split()
                words2 = text2.lower().split()
                # If either text is too short, consider them different
                if len(words1) < 3 or len(words2) < 3:
                    return True
                # Create word sequences for comparison
                seq1 = ' '.join(words1)
                seq2 = ' '.join(words2)
                # Check if one is contained within the other
                if seq1 in seq2 or seq2 in seq1:
                    return False
                # Calculate word overlap
                words1_set = set(words1)
                words2_set = set(words2)
                overlap = len(words1_set.intersection(words2_set))
                max_words = max(len(words1_set), len(words2_set))
                # If more than 50% overlap, consider it a duplicate
                return (overlap / max_words) < 0.5 if max_words > 0 else True

            # Create documents from transcript chunks
            documents = []
            total_captions = len(vtt_captions)
            processed_captions = 0
            chunk_size = 60  # Increased chunk size to 60 seconds
            current_chunk = []
            chunk_start = 0
            chunk_count = 0
            last_text = ""
            # Process captions with progress updates from 15-35%
            for caption in vtt_captions:
                cleaned_text = clean_caption(caption.text)
                if not cleaned_text:
                    continue
                start_seconds = _time_to_seconds(caption.start)
                # Only add text if it's substantially different from the last added text
                if is_substantial_difference(last_text, cleaned_text):
                    # Don't add if it's just a subset of any recent text in current chunk
                    if not any(cleaned_text in existing or existing in cleaned_text
                               for existing in current_chunk[-3:] if current_chunk):
                        current_chunk.append(cleaned_text)
                        last_text = cleaned_text
                # Create new chunk every chunk_size seconds or if chunk is getting too long
                if (start_seconds - chunk_start >= chunk_size and current_chunk) or \
                        (len(' '.join(current_chunk)) > 1000):  # Limit chunk size to ~1000 chars
                    if current_chunk:  # Only create chunk if there's content
                        chunk_count += 1
                        doc = Document(
                            page_content=" ".join(current_chunk),
                            metadata={
                                "title": info.get('title', ''),
                                "description": info.get('description', ''),
                                "author": info.get('uploader', ''),
                                "source": request.url,
                                "chunk_start": chunk_start,
                                "chunk_end": start_seconds,
                                "chunk_number": chunk_count
                            }
                        )
                        documents.append(doc)
                        current_chunk = []
                        chunk_start = start_seconds
                        last_text = ""
                processed_captions += 1
                if processed_captions % 100 == 0:  # Update every 100 captions
                    # Progress from 15% to 35%
                    percent = 15 + ((processed_captions / total_captions) * 20)
                    yield {"status": "progress", "data": {
                        "message": f"Processing transcript: {processed_captions}/{total_captions} captions",
                        "chunk": 2,
                        "total_chunks": 4,
                        "percent_complete": f"{percent:.1f}%"
                    }}
            # Add final chunk if any remains
            if current_chunk:
                chunk_count += 1
                doc = Document(
                    page_content=" ".join(current_chunk),
                    metadata={
                        "title": info.get('title', ''),
                        "description": info.get('description', ''),
                        "author": info.get('uploader', ''),
                        "source": request.url,
                        "chunk_start": chunk_start,
                        "chunk_end": _time_to_seconds(vtt_captions[-1].end),
                        "chunk_number": chunk_count
                    }
                )
                documents.append(doc)
            # Vectorstore initialization (35-40%)
            yield {"status": "progress", "data": {
                "message": "Initializing vector database...",
                "chunk": 3,
                "total_chunks": 4,
                "percent_complete": "40%"
            }}
            # Store documents in ChromaDB
            collection_name = sanitize_collection_name(str(request.collection_name))
            vectordb = get_vectorstore(
                request.api_key, collection_name, request.is_local,
                request.local_embedding_model)
            if not vectordb:
                raise Exception("Failed to initialize vector database")
            # Add documents in batches with progress updates (40-95%)
            total_docs = len(documents)
            docs_processed = 0
            batch_size = 100
            for i in range(0, len(documents), batch_size):
                batch = documents[i:i + batch_size]
                vectordb.add_documents(batch)
                docs_processed += len(batch)
                # Progress from 40% to 95%
                percent = 40 + ((docs_processed / total_docs) * 55)
                yield {"status": "progress", "data": {
                    "message": f"Embedding chunks in vector database: {docs_processed}/{total_docs}",
                    "chunk": 4,
                    "total_chunks": 4,
                    "percent_complete": f"{percent:.1f}%"
                }}
            # Final completion (95-100%)
            success_msg = f"Successfully processed and stored {chunk_count} transcript chunks. Total length: {sum(len(doc.page_content) for doc in documents)} characters"
            logger.info(success_msg)
            yield {"status": "progress", "data": {"message": success_msg, "chunk": 4, "total_chunks": 4, "percent_complete": "100%"}}
            # Save transcript to file
            collection_path = _get_collection_path(
                request.user_id, request.username,
                request.collection_id, request.collection_name
            )
            if not os.path.exists(collection_path):
                os.makedirs(collection_path, exist_ok=True)
            # Create filename using video title and timestamp
            safe_title = "".join(c for c in info.get('title', 'unknown')
                                 if c.isalnum() or c in (' ', '-', '_')).rstrip()
            folder_name = f"{safe_title}_youtube"
            folder_path = os.path.join(collection_path, folder_name)
            os.makedirs(folder_path, exist_ok=True)
            # Save metadata
            metadata = {
                "title": info.get('title', ''),
                "uploader": info.get('uploader', ''),
                "duration": info.get('duration', ''),
                "description": info.get('description', ''),
                "url": request.url
            }
            with open(os.path.join(folder_path, "metadata.json"), "w", encoding="utf-8") as f:
                json.dump(metadata, f, ensure_ascii=False, indent=2)
            # Save full transcript
            with open(os.path.join(folder_path, "transcript.txt"), "w", encoding="utf-8") as f:
                f.write(f"Title: {info.get('title', 'Unknown')}\n")
                f.write(f"Author: {info.get('uploader', 'Unknown')}\n")
                f.write(f"Duration: {info.get('duration', 'Unknown')} seconds\n")
                f.write(f"Source URL: {request.url}\n")
                f.write("\n--- Transcript ---\n\n")
                for doc in documents:
                    f.write(f"[{doc.metadata['chunk_start']:.1f}s - {doc.metadata['chunk_end']:.1f}s]\n")
                    f.write(f"{doc.page_content}\n\n")
            # Save chunked transcripts with timestamps
            with open(os.path.join(folder_path, "transcript_chunks.json"), "w", encoding="utf-8") as f:
                chunks = [{
                    "content": doc.page_content,
                    "start_time": doc.metadata.get("chunk_start", 0),
                    "end_time": doc.metadata.get("chunk_end", 0),
                    "chunk_number": doc.metadata.get("chunk_number", 0)
                } for doc in documents]
                json.dump(chunks, f, ensure_ascii=False, indent=2)
            # Log success
            logger.info(f"Saved transcript to {folder_path}")
            return documents
    except Exception as e:
        error_msg = f"Error processing YouTube transcript: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise Exception(error_msg)


def _time_to_seconds(time_str):
    """Convert VTT timestamp to seconds"""
    h, m, s = time_str.split(':')
    return float(h) * 3600 + float(m) * 60 + float(s)
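
------------------------------------------------
EXAMPLE: collecting youtube_transcript's return value
------------------------------------------------
youtube_transcript yields progress dicts and also *returns* the Document
list, so the final value arrives via StopIteration. A driver sketch (the
request object is assumed to be a populated YoutubeTranscriptRequest):

def run_ingest(request):
    gen = youtube_transcript(request)
    while True:
        try:
            event = next(gen)
            print(event["data"]["percent_complete"], event["data"]["message"])
        except StopIteration as done:
            return done.value  # the list of Documents stored in Chroma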
================================================
FILE: Backend/src/data/dataIntake/csvFallbackSplitting.py
================================================
from langchain_core.documents import Document
import pandas as pd
import io
import time
from typing import Generator


def split_csv_text(text: str, file_path: str, metadata: dict = None) -> Generator[dict | list, None, None]:
    """Split CSV text into chunks for embedding while preserving row integrity."""
    try:
        # Convert text back to DataFrame using StringIO
        yield {"status": "progress", "data": {"message": "Loading CSV data...", "chunk": 1, "total_chunks": 4, "percent_complete": "25%"}}
        df = pd.read_csv(io.StringIO(text))
        # Get headers
        headers = df.columns.tolist()
        # Calculate approximate number of rows per chunk (targeting ~2000 characters per chunk)
        yield {"status": "progress", "data": {"message": "Calculating chunk sizes...", "chunk": 2, "total_chunks": 4, "percent_complete": "50%"}}
        sample_row = df.iloc[0].to_string(index=False)
        chars_per_row = len(sample_row)
        rows_per_chunk = max(1, int(2000 / chars_per_row))
        documents = []
        total_rows = len(df)
        start_time = time.time()
        # Process DataFrame in chunks
        for i in range(0, total_rows, rows_per_chunk):
            # Calculate progress
            progress = min(100, int((i / total_rows) * 100))
            elapsed_time = time.time() - start_time
            est_remaining_time = "calculating..." if i == 0 else \
                f"{(elapsed_time / (i + 1)) * (total_rows - i):.1f}s"
            yield {
                "status": "progress",
                "data": {
                    "message": f"Processing rows {i} to {min(i + rows_per_chunk, total_rows)}...",
                    "chunk": 3,
                    "total_chunks": 4,
                    "percent_complete": f"{progress}%",
                    "est_remaining_time": est_remaining_time
                }
            }
            chunk_df = df.iloc[i:i + rows_per_chunk]
            # Convert chunk to string more efficiently
            chunk_text = []
            chunk_text.append(",".join(headers))  # Add headers
            # Convert rows to strings efficiently
            for _, row in chunk_df.iterrows():
                chunk_text.append(",".join(str(val) for val in row))
            chunk_content = "\n".join(chunk_text)
            # Create document with metadata
            doc_metadata = {"source": file_path, "chunk_start": i}
            if metadata:
                doc_metadata.update(metadata)
            documents.append(
                Document(page_content=chunk_content, metadata=doc_metadata))
        yield {"status": "progress", "data": {"message": "Finalizing chunks...", "chunk": 4, "total_chunks": 4, "percent_complete": "100%"}}
        print(f"Split CSV into {len(documents)} chunks")
        return documents
    except Exception as e:
        print(f"Error splitting CSV text: {str(e)}")
        yield {"status": "error", "message": f"Error splitting CSV text: {str(e)}"}
        return []
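
------------------------------------------------
EXAMPLE: row-preserving CSV chunking
------------------------------------------------
split_csv_text follows the same yield-progress/return-documents protocol as
youtube_transcript above, and every chunk repeats the header row so no record
is split mid-row. A self-contained sketch with synthetic data:

csv_text = "name,score\n" + "\n".join(f"user{i},{i}" for i in range(500))
gen = split_csv_text(csv_text, file_path="scores.csv")
while True:
    try:
        next(gen)  # progress dicts, ignored here
    except StopIteration as done:
        docs = done.value or []
        break
assert all(d.page_content.startswith("name,score") for d in docs)
print(f"{len(docs)} chunks, each carrying the header row")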
html = markdown.markdown(md_text) soup = BeautifulSoup(html, 'html.parser') return soup.get_text().strip() except Exception as e: print(f"Error loading MD: {str(e)}") return None async def load_html(file_path: str) -> str: """Load and process HTML file content""" try: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() # Parse HTML with BeautifulSoup soup = BeautifulSoup(content, 'html.parser') # Remove script and style elements for script in soup(["script", "style"]): script.decompose() # Get text content text = soup.get_text() # Break into lines and remove leading/trailing space lines = (line.strip() for line in text.splitlines()) # Break multi-headlines into a line each chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) # Drop blank lines text = ' '.join(chunk for chunk in chunks if chunk) return text except Exception as e: logging.error(f"Error loading HTML file {file_path}: {str(e)}") return None async def load_csv(file): try: loader = CSVLoader(file) data = loader.load() return data except Exception as e: print(f"Error loading CSV: {str(e)}") return None async def load_json(file): try: with open(file, 'r', encoding='utf-8') as f: data = json.load(f) return json.dumps(data, indent=2) except Exception as e: print(f"Error loading JSON: {str(e)}") return None def load_pptx(file): try: prs = Presentation(file) text = [] for slide in prs.slides: for shape in slide.shapes: if hasattr(shape, "text"): text.append(shape.text) return "\n".join(text).strip() except Exception as e: print(f"Error loading PPTX: {str(e)}") return None def load_xlsx(file): try: df = pd.read_excel(file) return df.to_string().strip() except Exception as e: print(f"Error loading XLSX: {str(e)}") return None async def load_docx(file): try: # Run the synchronous loader in a thread pool to avoid blocking def load_docx_sync(): loader = Docx2txtLoader(file) data = loader.load() return data[0].page_content if data else None content = await asyncio.get_event_loop().run_in_executor(None, load_docx_sync) if content: logging.info(f"Successfully loaded DOCX file: {file}") return content return None except Exception as e: logging.error(f"Error loading DOCX: {str(e)}") return None ================================================ FILE: Backend/src/data/dataIntake/getHtmlFiles.py ================================================ import os def get_html_files(directory): """Recursively get all HTML files in a directory and its subdirectories""" html_files = [] for root, _, files in os.walk(directory): for file in files: if file.endswith('.html'): file_path = os.path.join(root, file) html_files.append(file_path) return html_files ================================================ FILE: Backend/src/data/dataIntake/loadFile.py ================================================ import os import logging logger = logging.getLogger(__name__) from src.data.dataIntake.fileTypes.loadX import ( load_csv, load_docx, load_html, load_json, load_md, load_pptx, load_txt, load_xlsx, load_py, load_pdf, ) file_handlers = { "pdf": load_pdf, "docx": load_docx, "txt": load_txt, "md": load_md, "html": load_html, "csv": load_csv, "json": load_json, "pptx": load_pptx, "xlsx": load_xlsx, "py": load_py, } async def load_document(file: str): try: file_type = file.split(".")[-1].lower() logger.info(f"Loading file of type: {file_type}") # Get file size file_size = os.path.getsize(file) logger.info(f"File size: {file_size / (1024*1024):.2f}MB") handler = file_handlers.get(file_type) print(handler) if not handler: 
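The crawler pipeline pairs get_html_files with load_html to turn a crawl output directory into plain text. A minimal sketch of that pairing, assuming an already-populated directory (the path is illustrative):

import asyncio
from src.data.dataIntake.getHtmlFiles import get_html_files
from src.data.dataIntake.fileTypes.loadX import load_html

async def dump_site_text(directory: str):
    # load_html returns the page's visible text, or None on failure
    for path in get_html_files(directory):
        text = await load_html(path)
        if text:
            print(path, len(text), "chars")

asyncio.run(dump_site_text("/tmp/example_docs"))  # hypothetical crawl output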
================================================
FILE: Backend/src/data/dataIntake/loadFile.py
================================================
import os
import logging

logger = logging.getLogger(__name__)

from src.data.dataIntake.fileTypes.loadX import (
    load_csv,
    load_docx,
    load_html,
    load_json,
    load_md,
    load_pptx,
    load_txt,
    load_xlsx,
    load_py,
    load_pdf,
)

file_handlers = {
    "pdf": load_pdf,
    "docx": load_docx,
    "txt": load_txt,
    "md": load_md,
    "html": load_html,
    "csv": load_csv,
    "json": load_json,
    "pptx": load_pptx,
    "xlsx": load_xlsx,
    "py": load_py,
}


async def load_document(file: str):
    try:
        file_type = file.split(".")[-1].lower()
        logger.info(f"Loading file of type: {file_type}")
        # Get file size
        file_size = os.path.getsize(file)
        logger.info(f"File size: {file_size / (1024*1024):.2f}MB")
        handler = file_handlers.get(file_type)
        logger.debug(f"Resolved handler: {handler}")
        if not handler:
            logger.error(f"Unsupported file type: {file_type}")
            return None
        # Large PDFs: load_pdf reads pages in a worker thread and takes only
        # the file path, so no extra arguments are needed here
        if file_type == "pdf" and file_size > 25 * 1024 * 1024:  # 25MB
            logger.info("Large PDF detected - this may take a while")
        return await handler(file)
    except Exception as e:
        logger.error(f"Error loading file: {str(e)}")
        return None


================================================
FILE: Backend/src/data/dataIntake/textSplitting.py
================================================
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
import logging


def split_text(text: str, file_path: str, metadata: dict = None) -> list:
    """Split text into chunks for embedding."""
    try:
        # Handle None or empty text
        if not text:
            logging.error(f"Empty or None text received from {file_path}")
            return []
        # Pre-process text to remove excessive whitespace
        text = " ".join(text.split())
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=20,
            length_function=len,
            is_separator_regex=False,
            # Prioritize sentence boundaries
            separators=[". ", "? ", "! ", "\n\n", "\n", " ", ""]
        )
        # Directly split text and create documents in one go
        texts = text_splitter.split_text(text)
        # Create metadata if none provided
        if metadata is None:
            metadata = {}
        metadata["source"] = file_path
        docs = [Document(page_content=t.strip(), metadata=metadata.copy())
                for t in texts]
        if not docs:
            logging.warning(
                f"No documents created after splitting text from {file_path}")
        else:
            logging.info(
                f"Successfully split text into {len(docs)} chunks from {file_path}")
        return docs
    except Exception as e:
        logging.error(f"Error splitting text from {file_path}: {str(e)}")
        return []


================================================
FILE: Backend/src/data/database/checkAPIKey.py
================================================
from src.data.database.db import db


def check_api_key(user_id: int):
    """ check to see if the userId has an API key in SQLite """
    print("Checking API key for user:", user_id)
    try:
        conn = db()
        if not conn:
            print("Failed to connect to database")
            return False
        cursor = conn.cursor()
        # Look up any API key row for this user
        cursor.execute("""
            SELECT * FROM dev_api_keys WHERE user_id = ?
""", (user_id,)) api_key = cursor.fetchone() conn.close() print(f"API key count for user {user_id}: {api_key}") return api_key is not None except Exception as e: print(f"Error checking API key: {e}") return False ================================================ FILE: Backend/src/data/database/db.py ================================================ import sqlite3 import os import pathlib import platform IS_DEV = os.environ.get("IS_DEV") == "1" def get_user_data_path(): system = platform.system() home = os.path.expanduser("~") if system == "Darwin": # macOS base_path = os.path.join( home, "Library", "Application Support", "notate") elif system == "Windows": base_path = os.path.join(os.getenv("APPDATA"), "notate") else: # Linux and others base_path = os.path.join(home, ".config", "notate") # Add development subdirectory if in dev mode if IS_DEV: return os.path.join(base_path, "development") return base_path def db(): if IS_DEV: try: # Get the absolute path to the project root root_dir = pathlib.Path(__file__).parent.parent.parent.parent db_path = os.path.join(root_dir, "..", 'Database', 'database.sqlite') # Ensure the Database directory exists os.makedirs(os.path.dirname(db_path), exist_ok=True) print(f"Connected to Database at: {db_path}") return sqlite3.connect(db_path) except Exception as e: print(f"Error connecting to database: {e}") return None else: # For production, use the user data directory user_data_path = get_user_data_path() db_dir = os.path.join(user_data_path, "Database") db_path = os.path.join(db_dir, "database.sqlite") # Ensure the Database directory exists os.makedirs(db_dir, exist_ok=True) print(f"Connected to Database at: {db_path}") return sqlite3.connect(db_path) ================================================ FILE: Backend/src/data/database/getCollectionInfo.py ================================================ from src.data.database.db import db from dataclasses import dataclass from typing import Optional @dataclass class CollectionSettings: id: int user_id: int name: str description: str is_local: bool local_embedding_model: Optional[str] type: str files: Optional[str] created_at: str def get_collection_settings(user_id: str, collection_name: str) -> Optional[CollectionSettings]: """ Get collection settings for a specific user and collection name Args: user_id (str): The user ID collection_name (str): The name of the collection Returns: CollectionSettings: Collection settings object or None if not found """ try: conn = db() if not conn: print("Failed to connect to database") return None cursor = conn.cursor() cursor.execute(""" SELECT id, user_id, name, description, is_local, local_embedding_model, type, files, created_at FROM collections WHERE name = ? AND user_id = ? """, (collection_name, user_id)) row = cursor.fetchone() conn.close() if not row: return None return CollectionSettings( id=row[0], user_id=row[1], name=row[2], description=row[3], is_local=bool(row[4]), local_embedding_model=row[5], type=row[6], files=row[7], created_at=row[8] ) except Exception as e: print(f"Error retrieving collection settings: {e}") return None ================================================ FILE: Backend/src/data/database/getLLMApiKey.py ================================================ from src.data.database.db import db def get_llm_api_key(user_id, provider): try: conn = db() cursor = conn.cursor() cursor.execute( "SELECT key FROM api_keys WHERE user_id = ? 
AND provider = ?", (user_id, provider)) result = cursor.fetchone() conn.close() return result[0] if result else None except Exception as e: print(f"Error retrieving OpenAI API key: {e}") return None ================================================ FILE: Backend/src/endpoint/api.py ================================================ from typing import AsyncGenerator import json from src.endpoint.models import ChatCompletionRequest from transformers import TextIteratorStreamer from threading import Thread import logging from src.models.manager import model_manager from src.models.streamer import TextGenerator, StopOnInterrupt import uuid import time import torch import transformers logger = logging.getLogger(__name__) async def chat_completion_stream(request: ChatCompletionRequest) -> AsyncGenerator[str, None]: """Stream chat completion from the model""" try: model = model_manager.current_model if not model: yield f"data: {json.dumps({'error': 'No model loaded'})}\n\n" return print(request.messages) # Convert messages to prompts try: prompt = "" # Initialize prompt variable # Format messages without explicit User/Assistant markers for msg in request.messages: if msg.role == "system": prompt += f"{msg.content}\n" elif msg.role == "user": prompt += f"Question: {msg.content}\n" elif msg.role == "assistant": prompt += f"Response: {msg.content}\n" prompt += "Response: " logger.info(f"Generated prompt: {prompt}") except Exception as e: logger.error(f"Error formatting prompt: {str(e)}", exc_info=True) raise # Create text generator try: generator = TextGenerator( model, model_manager.current_tokenizer, model_manager.device) # For llama.cpp models, we don't need to pre-encode the input if model_manager.model_type != "llama.cpp": # Only encode for transformers models input_ids = model_manager.current_tokenizer.encode( prompt, return_tensors="pt") attention_mask = torch.ones_like(input_ids) if hasattr(model, "device"): input_ids = input_ids.to(model.device) attention_mask = attention_mask.to(model.device) except Exception as e: logger.error( f"Error setting up generator: {str(e)}", exc_info=True) raise if request.stream: try: # Different handling for llama.cpp vs transformers models if model_manager.model_type == "llama.cpp": # Use the TextGenerator's built-in streaming for llama.cpp stream_iterator = generator.generate( prompt=prompt, max_new_tokens=min(request.max_tokens or 2048, 2048), temperature=request.temperature or 0.7, top_p=request.top_p or 0.95, top_k=request.top_k or 40, repetition_penalty=1.2, stream=True ) async for chunk in stream_iterator: yield chunk yield "data: [DONE]\n\n" else: # Set up generation config for transformers models gen_config = { # Cap at 2048 if not specified "max_new_tokens": min(request.max_tokens or 2048, 2048), "temperature": request.temperature or 0.7, "top_p": request.top_p or 0.95, "top_k": request.top_k or 40, # Slightly lower for more focused sampling "repetition_penalty": 1.2, # Increased to reduce repetition "do_sample": True, "pad_token_id": model_manager.current_tokenizer.pad_token_id, "eos_token_id": model_manager.current_tokenizer.eos_token_id, "no_repeat_ngram_size": 5, # Increased to catch longer repetitive phrases "min_new_tokens": 32, # Increased minimum for more complete thoughts "max_time": 30.0, "stopping_criteria": transformers.StoppingCriteriaList([StopOnInterrupt()]), "forced_eos_token_id": model_manager.current_tokenizer.eos_token_id, "length_penalty": 0.8, # Slight penalty for longer sequences "num_return_sequences": 1, "remove_invalid_values": 
True } # Add [END] token to the tokenizer's special tokens special_tokens = {"additional_special_tokens": ["[END]"]} model_manager.current_tokenizer.add_special_tokens( special_tokens) logger.info(f"Generation config: {gen_config}") # Create streamer with token-by-token streaming streamer = TextIteratorStreamer( model_manager.current_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=None, # No timeout to prevent queue.Empty errors skip_word_before_colon=False, spaces_between_special_tokens=False, tokenizer_decode_kwargs={"skip_special_tokens": True} ) generation_kwargs = dict( input_ids=input_ids, attention_mask=attention_mask, streamer=streamer, **gen_config ) # Create thread for generation thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() # Generate a consistent ID for this completion completion_id = f"chatcmpl-{uuid.uuid4()}" # Send the initial role message response = { "id": completion_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": "local-model", "choices": [{ "index": 0, "delta": {"role": "assistant"}, "finish_reason": None }] } yield f"data: {json.dumps(response)}\n\n" # Stream the output accumulated_text = "" for new_text in streamer: if not new_text: continue # Split into individual characters/tokens for smoother streaming chars = list(new_text) for char in chars: accumulated_text += char response = { "id": completion_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": "local-model", "choices": [{ "index": 0, "delta": {"content": char}, "finish_reason": None }] } yield f"data: {json.dumps(response)}\n\n" # Send the final message response = { "id": completion_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": "local-model", "choices": [{ "index": 0, "delta": {}, "finish_reason": "stop" }] } yield f"data: {json.dumps(response)}\n\n" yield "data: [DONE]\n\n" except Exception as e: logger.error( f"Error during streaming: {str(e)}", exc_info=True) raise except Exception as e: logger.error(f"Error in chat completion: {str(e)}", exc_info=True) error_response = { "id": f"chatcmpl-{uuid.uuid4()}", "object": "chat.completion.chunk", "created": int(time.time()), "model": "local-model", "choices": [{ "index": 0, "delta": { "content": f"Error: {str(e)}" }, "finish_reason": "error" }] } yield f"data: {json.dumps(error_response)}\n\n" yield "data: [DONE]\n\n" # Make sure to send DONE even on error ================================================ FILE: Backend/src/endpoint/deleteStore.py ================================================ from src.endpoint.models import DeleteCollectionRequest from src.vectorstorage.vectorstore import get_vectorstore import logging logger = logging.getLogger(__name__) def delete_vectorstore_collection(data: DeleteCollectionRequest): try: logger.info(f"Deleting vectorstore collection: {data.collection_name}") vectorstore = get_vectorstore( data.api_key, data.collection_name, data.is_local) if vectorstore: vectorstore.delete_collection() return True return False except Exception as e: logger.error(f"Error deleting vectorstore collection: {str(e)}") return False ================================================ FILE: Backend/src/endpoint/devApiCall.py ================================================ from src.data.database.getCollectionInfo import get_collection_settings from src.data.database.getLLMApiKey import get_llm_api_key from src.endpoint.models import VectorStoreQueryRequest from src.endpoint.ragQuery import rag_query from 
src.endpoint.vectorQuery import query_vectorstore
from src.llms.llmQuery import llm_query
from src.endpoint.models import ChatCompletionRequest, QueryRequest


def vector_call(query_request: QueryRequest, user_id: str):
    """ VECTORSTORE QUERY IF NO MODEL PROVIDED IN REQUEST BODY """
    # The dev API sends a QueryRequest-shaped body (it carries `input`)
    print(f"API vector query received for user {user_id}")
    if not query_request.model:
        print(f"No model provided in request body for user {user_id}")
    collectionSettings = get_collection_settings(
        user_id, query_request.collection_name)
    # Check for missing settings before touching any of their attributes
    if not collectionSettings:
        raise ValueError("Collection settings not found")
    if collectionSettings.is_local == False:
        api_key = get_llm_api_key(int(user_id), "openai")
    else:
        api_key = None
    vectorStoreData = VectorStoreQueryRequest(
        query=query_request.input,
        collection=collectionSettings.id,
        collection_name=query_request.collection_name,
        user=user_id,
        api_key=api_key,
        top_k=query_request.top_k,
        is_local=collectionSettings.is_local,
        local_embedding_model=collectionSettings.local_embedding_model
    )
    return query_vectorstore(vectorStoreData, collectionSettings.is_local)


async def rag_call(query_request: QueryRequest, user_id: str):
    """ MODEL + VECTORSTORE QUERY IF MODEL AND COLLECTION NAME PROVIDED IN REQUEST BODY """
    print(f"Model provided in request body for user {user_id}")
    collectionSettings = get_collection_settings(
        user_id, query_request.collection_name)
    if not collectionSettings:
        raise ValueError("Collection settings not found")
    if query_request.is_local == False:
        api_key = get_llm_api_key(int(user_id), query_request.provider)
    else:
        api_key = None
    ragData = VectorStoreQueryRequest(
        query=query_request.input,
        collection=collectionSettings.id,
        collection_name=query_request.collection_name,
        user=user_id,
        api_key=api_key,
        top_k=query_request.top_k,
        is_local=collectionSettings.is_local,
        local_embedding_model=collectionSettings.local_embedding_model,
        temperature=query_request.temperature,
        max_completion_tokens=query_request.max_completion_tokens,
        top_p=query_request.top_p,
        frequency_penalty=query_request.frequency_penalty,
        presence_penalty=query_request.presence_penalty,
        provider=query_request.provider,
        model=query_request.model,
        is_ooba=query_request.is_ooba
    )
    return await rag_query(ragData, collectionSettings)


async def llm_call(query_request: ChatCompletionRequest, user_id: str):
    """ MODEL QUERY IF MODEL BUT NO COLLECTION NAME PROVIDED IN REQUEST BODY """
    print(
        f"Model but no collection name provided in request body for user {user_id}")
    if query_request.is_local == False:
        api_key = get_llm_api_key(int(user_id), query_request.provider)
    else:
        api_key = None
    return await llm_query(query_request, api_key)
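The three entry points above split cleanly on what the request carries. A routing sketch under that assumption (the dispatcher name is illustrative; the real wiring lives in the FastAPI endpoint):

async def route_dev_request(body, user_id: str):
    if body.model and body.collection_name:
        return await rag_call(body, user_id)   # retrieval + generation
    if body.model:
        return await llm_call(body, user_id)   # plain completion
    return vector_call(body, user_id)          # similarity search only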
================================================
FILE: Backend/src/endpoint/embed.py
================================================
from src.data.dataIntake.textSplitting import split_text
from src.data.dataIntake.loadFile import load_document
from src.endpoint.models import EmbeddingRequest
from src.vectorstorage.helpers.sanitizeCollectionName import sanitize_collection_name
from src.vectorstorage.vectorstore import get_vectorstore
from src.vectorstorage.embeddings import embed_chunk, chunk_list
import os
import multiprocessing
import concurrent.futures
import time
from typing import AsyncGenerator
from collections import deque
import logging

logger = logging.getLogger(__name__)


async def embed(data: EmbeddingRequest) -> AsyncGenerator[dict, None]:
    file_name = os.path.basename(data.file_path)
    try:
        yield {"status": "info", "message": f"Starting embedding process for file: {file_name}"}
        # Get file size
        file_size = os.path.getsize(data.file_path)
        if file_size > 25 * 1024 * 1024:  # If file is larger than 25MB
            yield {"status": "info", "message": f"Processing large file ({file_size / (1024*1024):.1f}MB). This may take longer."}
        text_output = await load_document(data.file_path)
        if text_output is None:
            raise Exception("Failed to load document")
        # Handle generator output from CSV loader
        if hasattr(text_output, '__iter__') and not isinstance(text_output, (str, list)):
            texts = []
            for item in text_output:
                if isinstance(item, dict) and "status" in item:
                    # Forward progress updates from CSV processing
                    yield item
                else:
                    texts = item
        else:
            yield {"status": "info", "message": "File loaded successfully"}
            # Check if file is CSV or PDF
            if file_name.lower().endswith('.csv'):
                texts = text_output  # CSV loader already returns list of documents
            elif file_name.lower().endswith('.pdf'):
                # PDF loader returns list of Documents, no need to split
                texts = text_output
            else:
                # Pass metadata to split_text if it exists
                texts = split_text(text_output, data.file_path,
                                   data.metadata if hasattr(data, 'metadata') else None)
        if not texts:
            raise Exception("No text content extracted from file")
        yield {"status": "info", "message": f"Split text into {len(texts)} chunks"}
        collection_name = sanitize_collection_name(str(data.collection_name))
        vectordb = get_vectorstore(
            data.api_key, collection_name, data.is_local,
            data.local_embedding_model)
        if not vectordb:
            raise Exception("Failed to initialize vector database")
        # Adjust chunk size based on file size
        chunk_size = min(50, max(10, int(1000000 / file_size)))  # Dynamic chunk size
        chunks = list(chunk_list(texts, chunk_size))
        total_chunks = len(chunks)
        yield {"status": "info", "message": f"Split into {total_chunks} chunks of {chunk_size} documents each"}
        start_time = time.time()
        time_history = deque(maxlen=5)
        # Process chunks with reduced parallelism for large files
        num_cores = max(1, min(multiprocessing.cpu_count() - 1, 4))  # Use fewer cores for large files
        yield {"status": "info", "message": f"Using {num_cores} CPU cores for processing"}
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_cores) as executor:
            futures = []
            for i, chunk in enumerate(chunks):
                chunk_arg = (vectordb, chunk, i + 1,
                             total_chunks, start_time, time_history)
                future = executor.submit(embed_chunk, chunk_arg)
                futures.append(future)
            # Process results as they complete
            for completed in concurrent.futures.as_completed(futures):
                try:
                    result = completed.result()
                    yield {"status": "progress", "data": result}
                except Exception as e:
                    logger.error(f"Error processing chunk: {str(e)}")
                    yield {"status": "error", "message": f"Error processing chunk: {str(e)}"}
                futures = [f for f in futures if not f.done()]  # Clean up completed futures
        yield {"status": "success", "message": "Embedding completed successfully"}
    except Exception as e:
        error_msg = f"Error embedding file: {str(e)}"
        logger.error(error_msg)
        yield {"status": "error", "message": error_msg}


================================================
FILE: Backend/src/endpoint/models.py
================================================
from pydantic import BaseModel
from typing import Optional, Dict, Any, List, Literal


class EmbeddingRequest(BaseModel):
    file_path: str
    api_key: Optional[str] = None
    collection: int
    collection_name: str
    user: int
    metadata: Optional[Dict[str, Any]] = None
    is_local: Optional[bool] = False
    local_embedding_model: Optional[str] = "granite-embedding:278m"


class ModelLoadRequest(BaseModel):
    model_name: str
    model_type:
Optional[str] = "auto" # 'auto', 'Transformers', 'llama.cpp', 'llamacpp_HF', 'ExLlamav2', 'ExLlamav2_HF', 'HQQ', 'TensorRT-LLM' device: Optional[str] = "auto" # 'cpu', 'cuda', 'auto' # Transformers specific settings load_in_8bit: Optional[bool] = False load_in_4bit: Optional[bool] = False use_flash_attention: Optional[bool] = False trust_remote_code: Optional[bool] = True use_safetensors: Optional[bool] = True max_memory: Optional[Dict[str, str]] = None compute_dtype: Optional[str] = "float16" # float16, bfloat16, float32 rope_scaling: Optional[Dict[str, Any]] = None use_cache: Optional[bool] = True revision: Optional[str] = None padding_side: Optional[str] = "right" use_fast_tokenizer: Optional[bool] = True hf_token: Optional[str] = None # HuggingFace token for gated models # ExLlamav2 specific settings max_seq_len: Optional[int] = None compress_pos_emb: Optional[float] = 1.0 alpha_value: Optional[float] = 1 # llama.cpp specific settings n_ctx: Optional[int] = 2048 n_batch: Optional[int] = 512 n_threads: Optional[int] = None n_threads_batch: Optional[int] = None n_gpu_layers: Optional[int] = 32 main_gpu: Optional[int] = 0 tensor_split: Optional[List[float]] = None mul_mat_q: Optional[bool] = True use_mmap: Optional[bool] = True use_mlock: Optional[bool] = False offload_kqv: Optional[bool] = False split_mode: Optional[str] = None flash_attn: Optional[bool] = False cache_type: Optional[str] = None cache_size: Optional[int] = None rope_scaling_type: Optional[str] = None rope_freq_base: Optional[float] = None rope_freq_scale: Optional[float] = None # HQQ specific settings hqq_backend: Optional[str] = "PYTORCH_COMPILE" # PYTORCH_COMPILE, ATEN, TENSORRT # TensorRT-LLM specific settings engine_dir: Optional[str] = None max_batch_size: Optional[int] = 1 max_input_len: Optional[int] = 2048 max_output_len: Optional[int] = 512 # Common settings model_path: Optional[str] = None # Custom path to model files if not in default location tokenizer_path: Optional[str] = None # Custom path to tokenizer if different from model path class Config: protected_namespaces = () class VectorStoreQueryRequest(BaseModel): query: str collection: Optional[int] = None collection_name: str user: int api_key: Optional[str] = None top_k: int = 5 is_local: Optional[bool] = False local_embedding_model: Optional[str] = "granite-embedding:278m" prompt: Optional[str] = None provider: Optional[str] = None model: Optional[str] = None temperature: Optional[float] = 0.5 max_completion_tokens: Optional[int] = 2048 top_p: Optional[float] = 1 frequency_penalty: Optional[float] = 0 presence_penalty: Optional[float] = 0 is_ooba: Optional[bool] = False character: Optional[str] = None is_ollama: Optional[bool] = False class YoutubeTranscriptRequest(BaseModel): url: str user_id: int collection_id: int username: str collection_name: str api_key: Optional[str] = None is_local: Optional[bool] = False local_embedding_model: Optional[str] = "granite-embedding:278m" class DeleteCollectionRequest(BaseModel): collection_id: int collection_name: str is_local: Optional[bool] = False api_key: Optional[str] = None class WebCrawlRequest(BaseModel): base_url: str max_workers: int collection_name: str collection_id: int user_id: int user_name: str api_key: Optional[str] = None is_local: Optional[bool] = False local_embedding_model: Optional[str] = "granite-embedding:278m" class QueryRequest(BaseModel): input: str prompt: Optional[str] = None provider: Optional[str] = None model: Optional[str] = None collection_name: Optional[str] = None top_k: 
Optional[int] = 5
    temperature: Optional[float] = 0.5
    max_completion_tokens: Optional[int] = 2048
    top_p: Optional[float] = 1
    frequency_penalty: Optional[float] = 0
    presence_penalty: Optional[float] = 0
    is_local: Optional[bool] = False
    is_ooba: Optional[bool] = False
    local_embedding_model: Optional[str] = "granite-embedding:278m"
    character: Optional[str] = None
    is_ollama: Optional[bool] = False


class Message(BaseModel):
    """A single message in a chat completion request"""
    role: Literal["system", "user", "assistant"]
    content: str
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request model for chat completion"""
    messages: List[Message]
    model: str = "local-model"
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 50
    n: Optional[int] = 1
    max_tokens: Optional[int] = 2048
    presence_penalty: Optional[float] = 0.1
    frequency_penalty: Optional[float] = 0.1
    repetition_penalty: Optional[float] = 1.1
    stop: Optional[List[str]] = None
    stream: Optional[bool] = True
    is_local: Optional[bool] = False
    is_ooba: Optional[bool] = False
    is_ollama: Optional[bool] = False
    # Referenced by rag_query/llm_call and the OpenAI provider; pydantic models
    # raise AttributeError on undeclared fields, so they must be declared here
    provider: Optional[str] = None
    max_completion_tokens: Optional[int] = 2048


class GenerateRequest(BaseModel):
    """Request model for raw text generation"""
    prompt: str
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 50
    repetition_penalty: Optional[float] = 1.1
    stop_sequences: Optional[List[str]] = None
    echo: Optional[bool] = False
    stream: Optional[bool] = True


================================================
FILE: Backend/src/endpoint/ragQuery.py
================================================
from src.endpoint.models import VectorStoreQueryRequest, ChatCompletionRequest
from src.endpoint.vectorQuery import query_vectorstore
from src.llms.llmQuery import llm_query


async def rag_query(data: VectorStoreQueryRequest, collectionInfo):
    try:
        results = query_vectorstore(data, data.is_local)
        data.prompt = f"The following is the data that the user has provided via their custom data collection: " + \
            f"\n\n{results}" + \
            f"\n\nCollection/Store Name: {collectionInfo.name}" + \
            f"\n\nCollection/Store Files: {collectionInfo.files}" + \
            f"\n\nCollection/Store Description: {collectionInfo.description}"
        chat_completion_request = ChatCompletionRequest(
            messages=[
                {"role": "system", "content": data.prompt},
                {"role": "user", "content": data.query}
            ],
            model=data.model,
            temperature=data.temperature,
            max_completion_tokens=data.max_completion_tokens,
            top_p=data.top_p,
            frequency_penalty=data.frequency_penalty,
            presence_penalty=data.presence_penalty,
            provider=data.provider,
            is_local=data.is_local
        )
        llm_response = await llm_query(chat_completion_request, data.api_key)
        return llm_response
    except Exception as e:
        print(e)
        raise e
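rag_query stitches the retrieval results into the system prompt and then defers to llm_query. A minimal sketch of driving it end to end, assuming a collection named "docs" already exists (all values here are illustrative, and the api_key is a placeholder):

import asyncio
from src.endpoint.models import VectorStoreQueryRequest
from src.endpoint.ragQuery import rag_query
from src.data.database.getCollectionInfo import get_collection_settings

async def demo():
    settings = get_collection_settings("1", "docs")  # illustrative user/collection
    if settings is None:
        return
    req = VectorStoreQueryRequest(
        query="What does the ingest pipeline do?",
        collection_name="docs",
        user=1,
        provider="openai",
        model="gpt-4o-mini",  # illustrative model name
        api_key="sk-...",     # placeholder, not a real key
    )
    print(await rag_query(req, settings))

asyncio.run(demo())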
================================================
FILE: Backend/src/endpoint/transcribe.py
================================================
from src.voice.voice_to_text import initialize_model
import os
import tempfile
from fastapi import UploadFile, File, HTTPException

# Global variables
model = None
ffmpeg_path = None


async def transcribe_audio(audio_file: UploadFile = File(...), model_name: str = "base") -> dict:
    """Transcribe audio using Whisper."""
    temp_file = None
    try:
        # Initialize model and verify FFmpeg is available
        model = initialize_model(model_name)
        if not model:
            raise HTTPException(
                status_code=500, detail="FFmpeg not found or not working")
        # Create temporary file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        content = await audio_file.read()
        temp_file.write(content)
        temp_file.flush()
        temp_file.close()
        result = model.transcribe(temp_file.name)
        return {
            "status": "success",
            "text": result["text"],
            "language": result.get("language", "unknown"),
            "segments": result.get("segments", [])
        }
    except Exception as e:
        print(f"Error transcribing audio: {str(e)}")
        return {
            "status": "error",
            "error": str(e)
        }
    finally:
        if temp_file and os.path.exists(temp_file.name):
            try:
                os.unlink(temp_file.name)
                print(f"Deleted temporary file: {temp_file.name}")
            except Exception as e:
                print(
                    f"Warning: Could not delete temporary file {temp_file.name}: {str(e)}")


================================================
FILE: Backend/src/endpoint/vectorQuery.py
================================================
from src.endpoint.models import VectorStoreQueryRequest
from src.vectorstorage.helpers.sanitizeCollectionName import sanitize_collection_name
from src.vectorstorage.vectorstore import get_vectorstore


def query_vectorstore(data: VectorStoreQueryRequest, is_local: bool):
    try:
        collection_name = sanitize_collection_name(str(data.collection_name))
        vectordb = get_vectorstore(
            data.api_key, collection_name, is_local, data.local_embedding_model)
        results = vectordb.similarity_search(data.query, k=data.top_k)
        return {
            "status": "success",
            "results": [{"content": doc.page_content, "metadata": doc.metadata} for doc in results],
        }
    except Exception as e:
        print(f"Error querying vectorstore: {str(e)}")
        return {"status": "error", "message": str(e)}


================================================
FILE: Backend/src/endpoint/webcrawl.py
================================================
from src.data.dataIntake.fileTypes.loadX import load_html
from src.data.dataIntake.textSplitting import split_text
from src.data.dataIntake.getHtmlFiles import get_html_files
from src.data.dataFetch.webcrawler import WebCrawler
from src.endpoint.models import WebCrawlRequest
from src.vectorstorage.vectorstore import get_vectorstore
from typing import Generator
import asyncio
import json
import os
from urllib.parse import urlparse
import logging


def webcrawl(data: WebCrawlRequest, cancel_event=None) -> Generator[dict, None, None]:
    try:
        # Create web crawler instance with all required fields
        scraper = WebCrawler(
            data.base_url,
            data.user_id,
            data.user_name,
            data.collection_id,
            data.collection_name,
            max_workers=data.max_workers,
            cancel_event=cancel_event
        )
        # Yield progress updates during scraping
        for progress in scraper.scrape():
            if progress:
                yield f"data: {json.dumps(progress)}"
        # After scraping, process and embed all HTML files
        root_url_dir = urlparse(
            data.base_url).netloc.replace(".", "_") + "_docs"
        collection_path = os.path.join(scraper.output_dir, root_url_dir)
        vector_store = get_vectorstore(
            data.api_key, data.collection_name, data.is_local,
            data.local_embedding_model)
        # Get all HTML files recursively
        html_files = get_html_files(collection_path)
        print(f"Found {len(html_files)} HTML files")
        # Process files in batches for better performance
        batch_size = 50
        total_batches = (len(html_files) + batch_size - 1) // batch_size
        for i in range(0, len(html_files), batch_size):
            batch = html_files[i:i + batch_size]
            batch_docs = []
            for file_path in batch:
                # load_html is a coroutine; this generator is synchronous, so
                # drive it to completion here (assumes no event loop is already
                # running in this thread)
                content = asyncio.run(load_html(file_path))
                if content:
                    split_content = split_text(content, file_path)
                    batch_docs.extend(split_content)
            if batch_docs:
                vector_store.add_documents(batch_docs)
            current_batch = i//batch_size + 1
            progress_data = {
                "status": "progress",
                "data": {
                    "message": f"Part 2 of 2: Processing documents batch {current_batch}/{total_batches}",
                    "chunk": current_batch,
                    "total_chunks": total_batches,
                    "percent_complete": f"{(current_batch/total_batches * 100):.1f}%"
                }
            }
            yield f"data: {json.dumps(progress_data)}"
        final_message = f"Successfully crawled and embedded {len(scraper.visited_urls)} pages from {data.base_url}"
        success_data = {
            "status": "success",
            "data": {
                "message": final_message
            }
        }
        yield f"data: {json.dumps(success_data)}"
    except Exception as e:
        error_message = str(e)
        print(f"Error during webcrawl: {error_message}")
        logging.error(f"Error during webcrawl: {error_message}")
        error_data = {
            "status": "error",
            "data": {
                "message": error_message
            }
        }
        yield f"data: {json.dumps(error_data)}"
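webcrawl is a plain generator, so a caller can drain it to observe crawl and embedding progress. A usage sketch with illustrative request values:

from src.endpoint.models import WebCrawlRequest
from src.endpoint.webcrawl import webcrawl

req = WebCrawlRequest(
    base_url="https://example.com",  # illustrative target
    max_workers=4,
    collection_name="docs",
    collection_id=1,
    user_id=1,
    user_name="demo",
    is_local=True,
)
for event in webcrawl(req):
    print(event)  # each event is an SSE-style "data: {...}" string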
"total_chunks": total_batches, "percent_complete": f"{(current_batch/total_batches * 100):.1f}%" } } yield f"data: {json.dumps(progress_data)}" final_message = f"Successfully crawled and embedded {len(scraper.visited_urls)} pages from {data.base_url}" success_data = { "status": "success", "data": { "message": final_message } } yield f"data: {json.dumps(success_data)}" except Exception as e: error_message = str(e) print(f"Error during webcrawl: {error_message}") logging.error(f"Error during webcrawl: {error_message}") error_data = { "status": "error", "data": { "message": error_message } } yield f"data: {json.dumps(error_data)}" ================================================ FILE: Backend/src/llms/llmQuery.py ================================================ from src.endpoint.models import ChatCompletionRequest from src.llms.providers.ooba import ooba_query from src.llms.providers.openai import openai_query from src.llms.providers.ollama import ollama_query from src.llms.providers.local import local_query from typing import Optional async def llm_query(data: ChatCompletionRequest, api_key: Optional[str] = None): try: if data.is_ooba: return ooba_query(data, data.messages) elif data.is_ollama is None: return ollama_query(data, data.messages) elif data.is_local: return await local_query(data) else: return openai_query(data, api_key, data.messages) except Exception as e: print(f"Error in llm_query: {str(e)}") raise e ================================================ FILE: Backend/src/llms/messages/formMessages.py ================================================ from src.endpoint.models import QueryRequest def form_messages(data: QueryRequest): try: if not data.prompt: raise ValueError("System prompt cannot be null") query_content = data.query if hasattr( data, 'query') else data.input if not query_content: raise ValueError("User query/input cannot be null") messages = [ {"role": "system", "content": data.prompt}, {"role": "user", "content": query_content} ] return messages except Exception as e: print(f"Error in form_messages: {str(e)}") raise e ================================================ FILE: Backend/src/llms/providers/local.py ================================================ import asyncio import json import time import logging from src.endpoint.api import chat_completion_stream from src.endpoint.models import ChatCompletionRequest, ModelLoadRequest from src.models.manager import model_manager from src.models.exceptions import ModelLoadError logger = logging.getLogger(__name__) async def local_query(data: ChatCompletionRequest): try: # Check if model is loaded and load it if necessary if not model_manager.is_model_loaded() or model_manager.model_name != data.model: logger.info(f"Loading model {data.model} as it is not currently loaded") # Create model load request load_request = ModelLoadRequest( model_name=data.model, model_type="Transformers", # Default to Transformers for now device="auto", trust_remote_code=True, use_safetensors=True, compute_dtype="float16" ) try: # Load the model model_manager.load_model(load_request) logger.info(f"Successfully loaded model {data.model}") except ModelLoadError as e: logger.error(f"Failed to load model {data.model}: {str(e)}") raise # Get the generator response_gen = chat_completion_stream(data) combined_content = "" response_id = None finish_reason = None # Process each chunk async for chunk in response_gen: if chunk.startswith("data: "): chunk = chunk[6:] # Remove "data: " prefix if chunk.strip() == "[DONE]": continue try: chunk_data = 
================================================
FILE: Backend/src/llms/providers/local.py
================================================
import asyncio
import json
import time
import logging
from src.endpoint.api import chat_completion_stream
from src.endpoint.models import ChatCompletionRequest, ModelLoadRequest
from src.models.manager import model_manager
from src.models.exceptions import ModelLoadError

logger = logging.getLogger(__name__)


async def local_query(data: ChatCompletionRequest):
    try:
        # Check if model is loaded and load it if necessary
        if not model_manager.is_model_loaded() or model_manager.model_name != data.model:
            logger.info(
                f"Loading model {data.model} as it is not currently loaded")
            # Create model load request
            load_request = ModelLoadRequest(
                model_name=data.model,
                model_type="Transformers",  # Default to Transformers for now
                device="auto",
                trust_remote_code=True,
                use_safetensors=True,
                compute_dtype="float16"
            )
            try:
                # Load the model
                model_manager.load_model(load_request)
                logger.info(f"Successfully loaded model {data.model}")
            except ModelLoadError as e:
                logger.error(f"Failed to load model {data.model}: {str(e)}")
                raise
        # Get the generator
        response_gen = chat_completion_stream(data)
        combined_content = ""
        response_id = None
        finish_reason = None
        # Process each chunk
        async for chunk in response_gen:
            if chunk.startswith("data: "):
                chunk = chunk[6:]  # Remove "data: " prefix
            if chunk.strip() == "[DONE]":
                continue
            try:
                chunk_data = json.loads(chunk)
                if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
                    choice = chunk_data["choices"][0]
                    if "delta" in choice:
                        delta = choice["delta"]
                        if "content" in delta:
                            combined_content += delta["content"]
                    if "finish_reason" in choice and choice["finish_reason"]:
                        finish_reason = choice["finish_reason"]
                    if not response_id:
                        response_id = chunk_data.get("id")
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse chunk as JSON: {str(e)}")
                continue
        # Create final response structure
        response = {
            "id": response_id or f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": data.model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": combined_content
                },
                "finish_reason": finish_reason or "stop"
            }]
        }
        return response
    except Exception as e:
        logger.error(f"Error in local_query: {str(e)}", exc_info=True)
        raise


================================================
FILE: Backend/src/llms/providers/ollama.py
================================================
from src.endpoint.models import QueryRequest
import requests
import json
import time


def ollama_query(data: QueryRequest, messages: list = None):
    try:
        print("Local Ollama model enabled")
        model_data = {
            "model": data.model,
            "messages": messages,
            "stream": False,  # Disable streaming for now
            "max_tokens": data.max_completion_tokens,
            "keep_alive": -1,
        }
        print(f"Model data: {model_data}")
        response = requests.post(
            "http://localhost:11434/api/chat", json=model_data)
        print(f"Raw response: {response.text}")
        if response.status_code == 200:
            try:
                response_json = response.json()
                print(f"Parsed response: {response_json}")
                # Extract content from the nested message structure
                content = response_json.get("message", {}).get(
                    "content", "No response from model")
                # Standardized response format
                return {
                    "id": f"local-{data.model}-{int(time.time())}",
                    "choices": [{
                        "finish_reason": "stop",
                        "index": 0,
                        "message": {
                            "content": content,
                            "role": "assistant"
                        }
                    }],
                    "created": int(time.time()),
                    "model": data.model,
                    "object": "chat.completion",
                    "usage": {
                        "completion_tokens": -1,  # Token count not available for local models
                        "prompt_tokens": -1,
                        "total_tokens": -1
                    }
                }
            except json.JSONDecodeError as e:
                print(f"JSON decode error: {e}")
                raise ValueError(
                    f"Failed to parse response from Ollama: {e}")
        # Non-200 response: surface the failure to the caller
        raise ValueError(
            f"Ollama returned status {response.status_code}: {response.text}")
    except Exception as e:
        print(f"Error in ollama_query: {str(e)}")
        raise e


================================================
FILE: Backend/src/llms/providers/ooba.py
================================================
from src.endpoint.models import QueryRequest
import requests


def ooba_query(data: QueryRequest, messages: list = None):
    try:
        print("Ooba mode enabled")
        ooba_data = {
            "messages": messages,
            "mode": "chat",
            "character": data.character
        }
        response = requests.post(
            "http://127.0.0.1:5000/v1/chat/completions", json=ooba_data)
        return response.json()
    except Exception as e:
        print(f"Error in ooba_query: {str(e)}")
        raise e
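Every provider normalizes its reply to the OpenAI chat.completion dict shape, so callers can read the answer the same way regardless of backend. An illustrative accessor:

def extract_reply(response: dict) -> str:
    # Works for local_query, ollama_query, and openai_query (after model_dump) alike
    return response["choices"][0]["message"]["content"]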
================================================
FILE: Backend/src/llms/providers/openai.py
================================================
from src.endpoint.models import QueryRequest
from openai import OpenAI
from typing import Optional


def openai_query(data: QueryRequest, api_key: Optional[str] = None, messages: list = None):
    try:
        # Avoid printing the raw API key to stdout
        print(f"OpenAI query (API key {'set' if api_key else 'missing'})")
        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model=data.model,
            messages=messages,
            response_format={"type": "text"},
            temperature=data.temperature,
            max_completion_tokens=data.max_completion_tokens,
            top_p=data.top_p,
            frequency_penalty=data.frequency_penalty,
            presence_penalty=data.presence_penalty
        )
        # Convert OpenAI response to dict for consistent format
        return response.model_dump()
    except Exception as e:
        print(f"Error in openai_query: {str(e)}")
        raise e


================================================
FILE: Backend/src/models/__init__.py
================================================


================================================
FILE: Backend/src/models/exceptions.py
================================================
class ModelLoadError(Exception):
    """Exception raised when there is an error loading a model."""
    pass


class ModelNotFoundError(Exception):
    """Exception raised when a requested model cannot be found."""
    pass


class ModelDownloadError(Exception):
    """Exception raised when there is an error downloading a model."""
    pass


================================================
FILE: Backend/src/models/loaders/__init__.py
================================================
from .transformers import TransformersLoader
from .llamacpp import LlamaCppLoader
from .llamaccphf import LlamaCppHFLoader
from .exllama import ExLlamaV2Loader, ExLlamaV2HFLoader
from .hqq import HQQLoader
from .tensorrt import TensorRTLoader

__all__ = [
    'TransformersLoader',
    'LlamaCppLoader',
    'LlamaCppHFLoader',
    'ExLlamaV2Loader',
    'ExLlamaV2HFLoader',
    'HQQLoader',
    'TensorRTLoader',
]
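The loaders package exposes one loader class per backend. A hypothetical dispatch from ModelLoadRequest.model_type to a class (the actual selection presumably lives in the model manager, not shown here):

from src.models.loaders import TransformersLoader, LlamaCppLoader, HQQLoader

LOADER_BY_TYPE = {
    "Transformers": TransformersLoader,
    "llama.cpp": LlamaCppLoader,
    "HQQ": HQQLoader,
}

def pick_loader(model_type: str):
    try:
        return LOADER_BY_TYPE[model_type]
    except KeyError:
        raise ValueError(f"No loader registered for model type: {model_type}")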
================================================
FILE: Backend/src/models/loaders/base.py
================================================
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import logging

from src.endpoint.models import ModelLoadRequest
from src.models.exceptions import ModelLoadError

logger = logging.getLogger(__name__)


class BaseLoader(ABC):
    """
    Abstract base class for model loaders.

    This class defines the interface that all model loaders must implement
    and provides some common utility methods.

    Attributes:
        request (ModelLoadRequest): The request object containing loading parameters
        manager (Any): Reference to the model manager instance
        model_path (Path): Path to the model files
    """

    def __init__(self, request: ModelLoadRequest, manager: Any):
        """
        Initialize the loader with request parameters and manager reference.

        Args:
            request: ModelLoadRequest object containing all loading parameters
            manager: Reference to the ModelManager instance
        """
        self.request = request
        self.manager = manager
        self.model_path = self._resolve_model_path()

    @abstractmethod
    def load(self) -> Tuple[Any, Any]:
        """
        Load the model and tokenizer.

        Returns:
            Tuple containing (model, tokenizer)

        Raises:
            ModelLoadError: If there's an error during model loading
        """
        pass

    @abstractmethod
    def get_metadata(self) -> Optional[Dict[str, Any]]:
        """
        Get model metadata without loading the full model.

        Returns:
            Dictionary containing model metadata or None if not available
        """
        pass

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """
        Get the current model configuration.

        Returns:
            Dictionary containing model configuration
        """
        pass

    def _resolve_model_path(self) -> Path:
        """
        Resolve the model path from the request parameters.

        Returns:
            Path object pointing to the model location

        Raises:
            ModelLoadError: If the path cannot be resolved
        """
        try:
            if self.request.model_path:
                path = Path(self.request.model_path)
            else:
                path = Path(f"models/{self.request.model_name}")
            # Create parent directories if they don't exist
            path.parent.mkdir(parents=True, exist_ok=True)
            return path
        except Exception as e:
            raise ModelLoadError(f"Failed to resolve model path: {str(e)}")

    def get_request_dict(self) -> Dict[str, Any]:
        """
        Convert the request object to a dictionary, filtering out None values.

        Returns:
            Dictionary containing all non-None request parameters
        """
        # ModelLoadRequest is a pydantic model, not a dataclass, so use its own
        # serializer rather than dataclasses.asdict (which raises TypeError)
        request_dict = (self.request.model_dump()
                        if hasattr(self.request, "model_dump")
                        else self.request.dict())
        return {k: v for k, v in request_dict.items() if v is not None}

    def log_loading_info(self) -> None:
        """Log information about the model being loaded."""
        logger.info(f"Loading model: {self.request.model_name}")
        logger.info(f"Model type: {self.request.model_type}")
        logger.info(f"Model path: {self.model_path}")
        logger.info(f"Device: {self.request.device}")

    @staticmethod
    def cleanup(model: Any) -> None:
        """
        Clean up model resources.

        Args:
            model: The model instance to clean up
        """
        try:
            if hasattr(model, 'cpu'):
                model.cpu()
            del model
        except Exception as e:
            logger.warning(f"Error during model cleanup: {str(e)}")

    def validate_model_path(self) -> None:
        """
        Validate that the model path exists and is accessible.

        Raises:
            ModelLoadError: If the model path is invalid or inaccessible
        """
        if not self.model_path.exists():
            raise ModelLoadError(
                f"Model path does not exist: {self.model_path}")

    def get_common_metadata(self) -> Dict[str, Any]:
        """
        Get common metadata that applies to all model types.

        Returns:
            Dictionary containing common metadata fields
        """
        return {
            "model_name": self.request.model_name,
            "model_type": self.request.model_type,
            "model_path": str(self.model_path),
            "device": self.request.device,
            "file_size": self.model_path.stat().st_size if self.model_path.exists() else None,
        }

    def validate_request(self) -> None:
        """
        Validate the model load request parameters.

        Raises:
            ModelLoadError: If the request parameters are invalid
        """
        if not self.request.model_name:
            raise ModelLoadError("Model name is required")
        if not self.request.model_type:
            raise ModelLoadError("Model type is required")

    def check_dependencies(self) -> None:
        """
        Check if all required dependencies are installed.

        Raises:
            ModelLoadError: If any required dependency is missing
        """
        pass  # Implement in specific loaders

    def prepare_loading(self) -> None:
        """
        Prepare for model loading by performing all necessary checks.

        This method combines several validation steps and should be called
        at the start of the load method in implementing classes.

        Raises:
            ModelLoadError: If any preparation step fails
        """
        try:
            self.validate_request()
            self.check_dependencies()
            self.validate_model_path()
            self.log_loading_info()
        except Exception as e:
            raise ModelLoadError(
                f"Failed to prepare for model loading: {str(e)}")

    def get_device_config(self) -> Dict[str, Any]:
        """
        Get device-specific configuration.

        Returns:
            Dictionary containing device configuration
        """
        import torch
        return {
            "device": self.request.device,
            "cuda_available": torch.cuda.is_available(),
            "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
            "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        }

    def get_memory_info(self) -> Dict[str, Any]:
        """
        Get system memory information.
Returns: Dictionary containing memory information """ try: import psutil vm = psutil.virtual_memory() return { "total_memory": vm.total, "available_memory": vm.available, "memory_percent": vm.percent, } except ImportError: return {} def get_system_info(self) -> Dict[str, Any]: """ Get system information. Returns: Dictionary containing system information """ import platform return { "platform": platform.system(), "platform_release": platform.release(), "python_version": platform.python_version(), "device_config": self.get_device_config(), "memory_info": self.get_memory_info(), } def log_error(self, error: Exception, context: str = "") -> None: """ Log an error with context. Args: error: The exception that occurred context: Additional context about where/why the error occurred """ error_msg = f"{context + ': ' if context else ''}{str(error)}" logger.error(error_msg, exc_info=True) def __repr__(self) -> str: """ Get string representation of the loader. Returns: String representation including model name and type """ return f"{self.__class__.__name__}(model_name={self.request.model_name}, model_type={self.request.model_type})" ================================================ FILE: Backend/src/models/loaders/exllama.py ================================================ import logging from typing import Any, Dict, Optional, Tuple from src.models.loaders.base import BaseLoader from src.models.exceptions import ModelLoadError from transformers import AutoTokenizer logger = logging.getLogger(__name__) class ExLlamaV2Loader(BaseLoader): """Loader for ExLlamaV2 models.""" def load(self) -> Tuple[Any, Any]: """Load an ExLlamav2 model.""" try: from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer import torch except ImportError: raise ModelLoadError( "exllamav2 is not installed. Please install it from the ExLlamaV2 repository") if not self.model_path.exists(): raise ModelLoadError( f"Model path does not exist: {self.model_path}") # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info(f"CUDA Device: {torch.cuda.get_device_name(0)}") logger.info( f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.0f}MB") if not torch.cuda.is_available(): raise ModelLoadError("GPU is required for ExLlama2") # Force CUDA device torch.set_default_device('cuda') torch.set_default_tensor_type('torch.cuda.FloatTensor') config = ExLlamaV2Config() config.model_dir = str(self.model_path) config.max_seq_len = self.request.max_seq_len or 2048 config.compress_pos_emb = self.request.compress_pos_emb config.alpha_value = self.request.alpha_value config.calculate_rotary_embedding_base() # Important for GPU performance logger.info(f"Loading model with config: {config.__dict__}") model = ExLlamaV2(config) # Force model to GPU model.load() for param in model.parameters(): param.data = param.data.cuda() logger.info( f"Model loaded on GPU. 
CUDA Memory: {torch.cuda.memory_allocated() / 1024**2:.0f}MB") logger.info( f"Device for first parameter: {next(model.parameters()).device}") tokenizer = ExLlamaV2Tokenizer(config) logger.info("Model and tokenizer loaded successfully") return model, tokenizer def get_metadata(self) -> Optional[Dict[str, Any]]: """Get model metadata.""" if not self.model_path.exists(): return None return { "model_type": "ExLlamav2", "model_path": str(self.model_path), "file_size": self.model_path.stat().st_size } def get_config(self) -> Dict[str, Any]: """Get model configuration.""" return { "model_type": "ExLlamav2", "model_name": self.request.model_name, "device": self.request.device, "max_seq_len": self.request.max_seq_len, "compress_pos_emb": self.request.compress_pos_emb, "alpha_value": self.request.alpha_value } class ExLlamaV2HFLoader(BaseLoader): """Loader for ExLlamaV2 models with HuggingFace tokenizer.""" def load(self) -> Tuple[Any, Any]: """Load an ExLlamav2 model with HF tokenizer.""" model = ExLlamaV2Loader(self.request, self.manager).load()[0] tokenizer_path = self.request.tokenizer_path or self.model_path tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, trust_remote_code=self.request.trust_remote_code, use_fast=self.request.use_fast_tokenizer, ) return model, tokenizer def get_metadata(self) -> Optional[Dict[str, Any]]: """Get model metadata.""" return ExLlamaV2Loader(self.request, self.manager).get_metadata() def get_config(self) -> Dict[str, Any]: """Get model configuration.""" return ExLlamaV2Loader(self.request, self.manager).get_config() ================================================ FILE: Backend/src/models/loaders/hqq.py ================================================ import logging from typing import Any, Dict, Optional, Tuple import requests from tqdm import tqdm from src.models.loaders.base import BaseLoader from src.models.exceptions import ModelLoadError, ModelDownloadError from transformers import AutoTokenizer logger = logging.getLogger(__name__) class HQQLoader(BaseLoader): """Loader for HQQ quantized models.""" def load(self) -> Tuple[Any, Any]: """Load an HQQ model.""" try: from hqq.core.quantize import HQQBackend, HQQLinear from hqq.models.hf.base import AutoHQQHFModel except ImportError: raise ModelLoadError( "hqq is not installed. 
Please install it from the HQQ repository") try: # Create models directory if it doesn't exist self.model_path.parent.mkdir(parents=True, exist_ok=True) logger.info(f"Using model path: {self.model_path}") # If it's a HuggingFace model ID and doesn't exist locally, try to download it if '/' in self.request.model_name and not self.model_path.exists(): self._download_model() if not self.model_path.exists(): raise ModelLoadError( f"Model path does not exist: {self.model_path}") logger.info(f"Loading HQQ model from {self.model_path}") model = AutoHQQHFModel.from_quantized(str(self.model_path)) logger.info("Model loaded successfully") logger.info(f"Setting HQQ backend to {self.request.hqq_backend}") HQQLinear.set_backend( getattr(HQQBackend, self.request.hqq_backend)) logger.info("HQQ backend set successfully") logger.info("Loading tokenizer") tokenizer = AutoTokenizer.from_pretrained( self.request.tokenizer_path or self.model_path, trust_remote_code=self.request.trust_remote_code, use_fast=self.request.use_fast_tokenizer, ) logger.info("Tokenizer loaded successfully") return model, tokenizer except Exception as e: raise ModelLoadError(f"Failed to load HQQ model: {str(e)}") def _download_model(self) -> None: """Download model from HuggingFace.""" try: # Get repository contents api_url = f"https://huggingface.co/api/models/{self.request.model_name}/tree/main" headers = {"Accept": "application/json"} if self.request.hf_token: headers["Authorization"] = f"Bearer {self.request.hf_token}" logger.info(f"Fetching repository contents from {api_url}") response = requests.get(api_url, headers=headers) response.raise_for_status() files = response.json() logger.info(f"Found {len(files)} files in repository") # Required files for HQQ models required_files = ['qmodel.pt', 'config.json', 'tokenizer.model', 'tokenizer_config.json', 'tokenizer.json'] logger.info(f"Required files: {required_files}") # Download each required file for file_name in required_files: file_info = next( (f for f in files if f['path'] == file_name), None) if not file_info: logger.error( f"Required file {file_name} not found in repository. 
Available files: {[f['path'] for f in files]}") raise ModelDownloadError( f"Required file {file_name} not found in repository {self.request.model_name}") download_url = f"https://huggingface.co/{self.request.model_name}/resolve/main/{file_name}" file_path = self.model_path / file_name # Download the file with progress bar logger.info( f"Downloading {file_name} ({file_info.get('size', 'unknown size')}) from {download_url}") response = requests.get( download_url, stream=True, headers=headers) response.raise_for_status() total_size = int(response.headers.get('content-length', 0)) block_size = 8192 # 8 KB with open(file_path, 'wb') as f, tqdm( desc=file_name, total=total_size, unit='iB', unit_scale=True, unit_divisor=1024, ) as pbar: for data in response.iter_content(block_size): size = f.write(data) pbar.update(size) logger.info( f"Successfully downloaded {file_name} to {file_path}") except Exception as e: logger.error( f"Failed to download model: {str(e)}", exc_info=True) # Clean up any partially downloaded files if self.model_path.exists(): import shutil shutil.rmtree(self.model_path) raise ModelDownloadError(f"Failed to download model: {str(e)}") def get_metadata(self) -> Optional[Dict[str, Any]]: """Get model metadata.""" if not self.model_path.exists(): return None return { "model_type": "HQQ", "model_path": str(self.model_path), "file_size": self.model_path.stat().st_size, "backend": self.request.hqq_backend } def get_config(self) -> Dict[str, Any]: """Get model configuration.""" return { "model_type": "HQQ", "model_name": self.request.model_name, "device": self.request.device, "backend": self.request.hqq_backend } ================================================ FILE: Backend/src/models/loaders/llamaccphf.py ================================================ from typing import Any, Tuple from src.models.loaders.llamacpp import LlamaCppLoader class LlamaCppHFLoader(LlamaCppLoader): """ Loader for llama.cpp models with HuggingFace tokenizer. Inherits from LlamaCppLoader but uses a separate HF tokenizer. """ def load(self) -> Tuple[Any, Any]: """Load model with HuggingFace tokenizer.""" from transformers import AutoTokenizer # Load the base model model, _ = super().load() # Load HuggingFace tokenizer tokenizer_path = self.request.tokenizer_path or ( self.request.model_path if self.request.model_path else f"models/{self.request.model_name}") tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, trust_remote_code=self.request.trust_remote_code, use_fast=self.request.use_fast_tokenizer, ) return model, tokenizer ================================================ FILE: Backend/src/models/loaders/llamacpp.py ================================================ import os import logging import requests from pathlib import Path from typing import Any, Dict, Optional, Tuple from tqdm import tqdm import sys from src.models.loaders.base import BaseLoader from src.endpoint.models import ModelLoadRequest from src.models.exceptions import ModelDownloadError, ModelLoadError logger = logging.getLogger(__name__) class LlamaCppLoader(BaseLoader): """ Loader for llama.cpp models. Handles both local and remote model loading, with support for GGUF format and various optimizations. 
""" def __init__(self, request: ModelLoadRequest, manager: Any): super().__init__(request, manager) self.llama = None self.cache = None def load(self) -> Tuple[Any, Any]: """Load a llama.cpp model and return the model and tokenizer.""" try: import torch from llama_cpp import Llama # Force CUDA environment variables before anything else if torch.cuda.is_available(): os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['LLAMA_CUDA_FORCE'] = '1' # Log CUDA information logger.info("CUDA is available") logger.info(f"CUDA Device: {torch.cuda.get_device_name(0)}") logger.info( f"Total CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.0f}MB") torch.cuda.empty_cache() # Get model path and ensure it exists model_path = self._get_model_path() if not model_path.exists(): raise ModelLoadError(f"Model file not found: {model_path}") logger.info(f"Loading model from path: {model_path}") # Simple CUDA parameters that match working Q8 configurations model_params = { "model_path": str(model_path), "n_ctx": int(self.request.n_ctx) if self.request.n_ctx is not None else 2048, "n_batch": int(self.request.n_batch) if self.request.n_batch is not None else 512, "n_gpu_layers": -1, "main_gpu": 0, "use_mmap": True, # Enable memory mapping "use_mlock": False, "verbose": True } # Log parameters logger.info(f"Loading model with parameters: {model_params}") # Load model model = Llama(**model_params) logger.info("Initial model load successful") # Simple CUDA test if torch.cuda.is_available(): try: logger.info("Testing model...") # Basic tokenization test tokens = model.tokenize(b"test") logger.info("Tokenization successful") # Log memory usage allocated = torch.cuda.memory_allocated() / 1024**2 reserved = torch.cuda.memory_reserved() / 1024**2 logger.info(f"CUDA Memory allocated: {allocated:.2f}MB") logger.info(f"CUDA Memory reserved: {reserved:.2f}MB") except Exception as e: logger.error(f"Model test failed: {e}") raise ModelLoadError(f"Failed to initialize model: {e}") logger.info("Model loaded successfully") return model, model except Exception as e: logger.error(f"Error loading model: {str(e)}", exc_info=True) raise ModelLoadError(f"Failed to load llama.cpp model: {str(e)}") def _get_model_path(self) -> Path: """Get and validate the model path, downloading if necessary.""" # Handle both direct file paths and model names if self.request.model_path: model_path = Path(self.request.model_path) else: # Convert HF style paths to filesystem paths safe_name = self.request.model_name.replace('/', os.path.sep) model_path = Path('models') / safe_name model_dir = model_path if model_path.is_dir() else model_path.parent model_dir.mkdir(parents=True, exist_ok=True) # Special handling for Ollama paths if '.ollama' in str(model_path): logger.info("Detected Ollama model path") # Determine Ollama directory based on OS if sys.platform == 'darwin': # macOS specific path ollama_dir = Path(os.path.expanduser('~/.ollama')) logger.info(f"Using macOS Ollama directory: {ollama_dir}") else: # Windows and Linux ollama_dir = Path(os.path.expandvars('%USERPROFILE%\\.ollama')) if not ollama_dir.exists(): ollama_dir = Path(os.path.expanduser('~/.ollama')) if not ollama_dir.exists(): raise ModelLoadError( f"Ollama directory not found at: {ollama_dir}") # Extract model name from path model_name = self.request.model_name if not model_name and 'registry.ollama.ai/library/' in str(model_path): model_name = str(model_path).split( 'registry.ollama.ai/library/')[-1].split('/')[0] logger.info(f"Using model name: {model_name}") # First 
check for the model file in the models directory models_dir = ollama_dir / 'models' logger.info(f"Checking Ollama models directory: {models_dir}") if models_dir.exists(): # First try to find a .gguf file gguf_files = list(models_dir.glob("**/*.gguf")) if gguf_files: logger.info(f"Found Ollama GGUF file: {gguf_files[0]}") return gguf_files[0] # Look for manifest manifest_dir = models_dir / 'manifests' / \ 'registry.ollama.ai' / 'library' / model_name manifest_path = manifest_dir / 'latest' logger.info(f"Looking for manifest at: {manifest_path}") if manifest_path.exists(): with open(manifest_path, 'r') as f: manifest = f.read() logger.info(f"Manifest content: {manifest}") import json try: manifest_data = json.loads(manifest) for layer in manifest_data.get('layers', []): if layer.get('mediaType') == 'application/vnd.ollama.image.model': blob_hash = layer.get('digest', '').replace( 'sha256:', 'sha256-') if blob_hash: # Check both blobs and models directories for the file possible_paths = [ models_dir / 'blobs' / blob_hash, ollama_dir / 'blobs' / blob_hash ] for blob_path in possible_paths: logger.info( f"Checking for blob at: {blob_path}") if blob_path.exists(): logger.info( f"Found Ollama model blob: {blob_path}") return blob_path except json.JSONDecodeError as e: logger.error(f"Failed to parse manifest: {e}") pass logger.warning(f"No Ollama model files found in: {models_dir}") raise ModelLoadError( f"Could not find Ollama model files in {models_dir}") # Check for existing GGUF files in the directory if model_dir.exists(): existing_gguf = list(model_dir.glob("*.gguf")) if existing_gguf: logger.info(f"Found existing GGUF model: {existing_gguf[0]}") return existing_gguf[0] # Only attempt to download if it looks like a HF model ID if '/' in self.request.model_name: return self._download_model(model_dir) raise ModelLoadError(f"No model files found in: {model_dir}") def _download_model(self, model_dir: Path) -> Path: """Download model from Hugging Face.""" logger.info(f"Attempting to download model: {self.request.model_name}") try: # Setup API request api_url = f"https://huggingface.co/api/models/{self.request.model_name}/tree/main" headers = {"Accept": "application/json"} if self.request.hf_token: headers["Authorization"] = f"Bearer {self.request.hf_token}" # Get repository contents response = requests.get(api_url, headers=headers) response.raise_for_status() files = response.json() # Find GGUF files gguf_files = [f for f in files if f.get( 'path', '').endswith('.gguf')] if not gguf_files: raise ModelDownloadError( f"No GGUF files found in repository {self.request.model_name}") # Sort by preference (q4_k_m) and size gguf_files.sort(key=lambda x: ( 0 if 'q4_k_m' in x['path'].lower() else 1, x.get('size', float('inf')) )) # Download the best candidate file_info = gguf_files[0] file_name = file_info['path'] download_url = f"https://huggingface.co/{self.request.model_name}/resolve/main/{file_name}" model_path = model_dir / file_name if not model_path.exists() or model_path.stat().st_size == 0: self._download_file(download_url, model_path, headers) return model_path except Exception as e: raise ModelDownloadError(f"Failed to download model: {str(e)}") def _download_file(self, url: str, path: Path, headers: Dict[str, str]) -> None: """Download a file with progress bar.""" response = requests.get(url, stream=True, headers=headers) response.raise_for_status() total_size = int(response.headers.get('content-length', 0)) block_size = 8192 with open(path, 'wb') as f, tqdm( desc=path.name, total=total_size, 
unit='iB', unit_scale=True, unit_divisor=1024, ) as pbar: for data in response.iter_content(block_size): size = f.write(data) pbar.update(size) def _get_model_params(self) -> Dict[str, Any]: """Configure model parameters based on request and system capabilities.""" import torch # Base parameters params = { "n_ctx": int(self.request.n_ctx) if self.request.n_ctx is not None else 2048, "n_batch": int(self.request.n_batch) if self.request.n_batch is not None else 512, "n_threads": int(self.request.n_threads) if self.request.n_threads is not None else os.cpu_count(), "verbose": True, # Enable verbose output for debugging } # Add CUDA parameters if available if torch.cuda.is_available(): logger.info("Configuring CUDA parameters...") # Force CUDA environment variables os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['LLAMA_CUDA_FORCE'] = '1' os.environ['LLAMA_FORCE_GPU'] = '1' # Force GPU usage os.environ['LLAMA_CPU_DISABLE'] = '1' # Disable CPU fallback # Enhanced CUDA parameters - optimized for GPU usage cuda_params = { "n_gpu_layers": -1, # Use all layers on GPU "main_gpu": 0, # Use the first GPU "tensor_split": None, # No tensor splitting "use_mmap": False, # Disable memory mapping "use_mlock": True, # Lock memory to prevent swapping "mul_mat_q": True, # Enable matrix multiplication "offload_kqv": True, # Keep KQV on GPU "f16_kv": True, # Use float16 for KV cache "logits_all": True, # Compute logits for all tokens "embedding": True # Use GPU for embeddings } params.update(cuda_params) logger.info(f"CUDA parameters configured: {cuda_params}") # Log CUDA device info logger.info(f"CUDA Device: {torch.cuda.get_device_name(0)}") logger.info( f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.0f}MB") # Add optional parameters if specified in request optional_params = { "tensor_split": self.request.tensor_split, "split_mode": self.request.split_mode, "cache_type": self.request.cache_type, } # Only add optional params if they have non-None values params.update( {k: v for k, v in optional_params.items() if v is not None}) logger.info(f"Final model parameters: {params}") return params def _configure_gpu_layers(self) -> int: """Configure the number of GPU layers based on hardware and request.""" import torch if not torch.cuda.is_available(): return 0 # Force environment variables for CUDA os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['LLAMA_CUDA_FORCE'] = '1' # If n_gpu_layers is specified in request, use that if self.request.n_gpu_layers is not None: return self.request.n_gpu_layers # Otherwise, use all layers on GPU return -1 # -1 means use all layers on GPU def _setup_cache(self, model: Any) -> None: """Setup model cache if supported.""" try: from llama_cpp import LlamaCache if hasattr(model, 'set_cache'): # Convert GB to bytes cache_size = self.request.cache_size * 1024 * 1024 * 1024 cache_type = "fp16" # or q8_0 or q4_0 depending on your needs model.set_cache(LlamaCache(capacity_bytes=cache_size)) logger.info( f"Initialized LLM cache with {self.request.cache_size}GB capacity using {cache_type}") except Exception as e: logger.warning(f"Failed to initialize cache: {e}") def get_metadata(self) -> Optional[Dict[str, Any]]: """Get model metadata without loading the full model.""" try: model_path = self._get_model_path() if not model_path.exists(): return None # Basic metadata metadata = { "model_type": "llama.cpp", "model_path": str(model_path), "file_size": model_path.stat().st_size, "format": "GGUF" if model_path.suffix == '.gguf' else "Unknown" } # Try to get 
additional metadata from the GGUF file try: from llama_cpp import Llama model = Llama(model_path=str(model_path), n_ctx=8, n_gpu_layers=0) metadata.update({ "n_vocab": model.n_vocab(), "n_ctx_train": model.n_ctx_train(), "n_embd": model.n_embd(), "desc": model.desc(), }) except: pass return metadata except Exception as e: logger.error(f"Error getting model metadata: {str(e)}") return None def get_config(self) -> Dict[str, Any]: """Get the current model configuration.""" return { "model_type": "llama.cpp", "n_ctx": self.request.n_ctx, "n_batch": self.request.n_batch, "n_gpu_layers": self.request.n_gpu_layers, "device": self.request.device, } @staticmethod def cleanup(model: Any) -> None: """Clean up model resources.""" try: del model except: pass ================================================ FILE: Backend/src/models/loaders/tensorrt.py ================================================ import logging from typing import Any, Dict, Optional, Tuple from src.models.loaders.base import BaseLoader from src.models.exceptions import ModelLoadError from transformers import AutoTokenizer logger = logging.getLogger(__name__) class TensorRTLoader(BaseLoader): """Loader for TensorRT-LLM models.""" def load(self) -> Tuple[Any, Any]: """Load a TensorRT-LLM model.""" try: import tensorrt_llm from tensorrt_llm.runtime import ModelConfig except ImportError: raise ModelLoadError( "tensorrt-llm is not installed. Please install it from the TensorRT-LLM repository") engine_path = self.request.engine_dir if self.request.engine_dir else self.model_path if not engine_path.exists(): raise ModelLoadError(f"Engine path does not exist: {engine_path}") config = ModelConfig( engine_dir=str(engine_path), max_batch_size=self.request.max_batch_size, max_input_len=self.request.max_input_len, max_output_len=int( self.request.max_output_len) if self.request.max_output_len is not None else None, ) model = tensorrt_llm.runtime.GenerationSession(config) tokenizer = AutoTokenizer.from_pretrained( self.request.tokenizer_path or str(engine_path), trust_remote_code=self.request.trust_remote_code, use_fast=self.request.use_fast_tokenizer, ) return model, tokenizer def get_metadata(self) -> Optional[Dict[str, Any]]: """Get model metadata.""" if not self.model_path.exists(): return None return { "model_type": "TensorRT-LLM", "model_path": str(self.model_path), "file_size": self.model_path.stat().st_size, "engine_dir": self.request.engine_dir } def get_config(self) -> Dict[str, Any]: """Get model configuration.""" return { "model_type": "TensorRT-LLM", "model_name": self.request.model_name, "device": self.request.device, "engine_dir": self.request.engine_dir, "max_batch_size": self.request.max_batch_size, "max_input_len": self.request.max_input_len, "max_output_len": self.request.max_output_len } ================================================ FILE: Backend/src/models/loaders/transformers.py ================================================ import logging from pathlib import Path from typing import Any, Dict, Optional, Tuple import torch from transformers import ( BitsAndBytesConfig, PreTrainedModel, ) from src.models.loaders.base import BaseLoader from src.models.exceptions import ModelLoadError logger = logging.getLogger(__name__) class TransformersLoader(BaseLoader): """ Loader for Hugging Face Transformers models. Handles both local and remote model loading with various optimizations. 
""" def load(self) -> Tuple[Any, Any]: """Load a transformers model and return the model and tokenizer.""" try: from transformers import AutoModelForCausalLM, AutoTokenizer logger.info(f"Loading model: {self.request.model_name}") logger.info(f"Model type: {self.request.model_type}") logger.info(f"Model path: {self.request.model_path}") logger.info(f"Device: {self.request.device}") # Configure model loading parameters model_kwargs = self._get_model_kwargs() # If we have a local path, use it directly if self.request.model_path and Path(self.request.model_path).exists(): logger.info( f"Loading model from local path: {self.request.model_path}") try: # Try to load tokenizer from local path first tokenizer = AutoTokenizer.from_pretrained( self.request.model_path, trust_remote_code=self.request.trust_remote_code, use_fast=self.request.use_fast_tokenizer, padding_side=self.request.padding_side ) logger.info("Loaded tokenizer from local path") # Load model from local path model = AutoModelForCausalLM.from_pretrained( self.request.model_path, **model_kwargs ) logger.info("Loaded model from local path") # Ensure model is on the correct device if not using device_map if model_kwargs.get("device_map") is None and hasattr(model, "to"): # Handle device placement if self.request.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = self.request.device model = model.to(device) logger.info(f"Moved model to device: {device}") return model, tokenizer except Exception as e: logger.warning(f"Failed to load from local path: {e}") raise ModelLoadError( f"Failed to load model from local path: {str(e)}") else: # Download from HuggingFace logger.info( "Attempting to download from HuggingFace: " + self.request.model_name) try: # Download and save tokenizer tokenizer = AutoTokenizer.from_pretrained( self.request.model_name, trust_remote_code=self.request.trust_remote_code, use_fast=self.request.use_fast_tokenizer, padding_side=self.request.padding_side ) if self.request.model_path: tokenizer.save_pretrained(self.request.model_path) logger.info( f"Tokenizer downloaded and saved to {self.request.model_path}") # Download and save config if self.request.model_path: from transformers import AutoConfig config = AutoConfig.from_pretrained( self.request.model_name, trust_remote_code=self.request.trust_remote_code ) config.save_pretrained(self.request.model_path) logger.info( f"Config downloaded and saved to {self.request.model_path}") # Download model weights logger.info( "Downloading model weights (this may take a while)...") model = AutoModelForCausalLM.from_pretrained( self.request.model_name, **model_kwargs ) # Save the model if we have a path if self.request.model_path: model.save_pretrained(self.request.model_path) logger.info( f"Model weights saved to {self.request.model_path}") return model, tokenizer except Exception as e: raise ModelLoadError(f"Failed to download model: {str(e)}") except Exception as e: raise ModelLoadError( f"Failed to load transformers model: {str(e)}") def _get_model_kwargs(self) -> Dict[str, Any]: """Get model loading parameters.""" # Get the compute dtype compute_dtype = torch.bfloat16 if self.request.compute_dtype == "bfloat16" else torch.float16 # Determine device map device_map = None if self.request.device == "cuda": if torch.cuda.is_available(): device_map = "auto" else: logger.warning( "CUDA requested but not available, falling back to CPU") self.request.device = "cpu" # Base parameters without gradient checkpointing load_params = { "low_cpu_mem_usage": 
True, "torch_dtype": compute_dtype, "trust_remote_code": self.request.trust_remote_code, "use_flash_attention_2": self.request.use_flash_attention, "device_map": device_map, "revision": self.request.revision, } # Only add gradient checkpointing for explicitly supported models model_name_lower = self.request.model_name.lower() if ("llama" in model_name_lower or "mistral" in model_name_lower or "mpt" in model_name_lower): load_params["use_gradient_checkpointing"] = True # Configure quantization if self.request.load_in_8bit or self.request.load_in_4bit: load_params["quantization_config"] = self._get_quantization_config() # Add optional parameters if self.request.max_memory is not None and self.request.device == "cuda": load_params["max_memory"] = self.request.max_memory if self.request.rope_scaling is not None: load_params["rope_scaling"] = self.request.rope_scaling if self.request.use_cache is False: load_params["use_cache"] = False # For model loading, return the original params with torch.dtype if not hasattr(self, '_serializing_for_response'): return load_params # For JSON response, convert torch.dtype to string response_params = load_params.copy() response_params["torch_dtype"] = str(compute_dtype) return response_params # Return string version for JSON serialization def _get_quantization_config(self) -> BitsAndBytesConfig: """Get quantization configuration.""" return BitsAndBytesConfig( load_in_8bit=self.request.load_in_8bit, load_in_4bit=self.request.load_in_4bit, bnb_4bit_compute_dtype=eval(f"torch.{self.request.compute_dtype}"), llm_int8_enable_fp32_cpu_offload=True, bnb_4bit_use_double_quant=True ) def get_metadata(self) -> Optional[Dict[str, Any]]: """Get model metadata without loading the full model.""" try: if '/' in self.request.model_name and not self.model_path.exists(): config = self._load_config(self.request.model_name) metadata = self._make_json_serializable(config.to_dict()) metadata['model_type'] = 'Transformers' return metadata if self.model_path.exists(): config = self._load_config(self.model_path) metadata = self._make_json_serializable(config.to_dict()) metadata['model_type'] = 'Transformers' return metadata return None except Exception as e: logger.error(f"Error getting model metadata: {str(e)}") return None def get_config(self) -> Dict[str, Any]: """Get the current model configuration.""" # Set flag to get JSON serializable params self._serializing_for_response = True load_params = self._get_model_kwargs() delattr(self, '_serializing_for_response') config = { "model_type": "Transformers", "model_name": self.request.model_name, "device": self.request.device, "load_params": load_params } if self.model_path.exists(): try: model_config = self._load_config(self.model_path) config["model_config"] = model_config.to_dict() except Exception as e: logger.warning(f"Could not load model config: {str(e)}") return self._make_json_serializable(config) def _make_json_serializable(self, obj: Any) -> Any: """Convert a dictionary with torch dtypes to JSON serializable format.""" if isinstance(obj, dict): return {k: self._make_json_serializable(v) for k, v in obj.items()} elif isinstance(obj, list): return [self._make_json_serializable(v) for v in obj] elif hasattr(obj, 'dtype'): # Handle torch dtypes return str(obj) return obj @staticmethod def cleanup(model: PreTrainedModel) -> None: """Clean up model resources.""" try: if hasattr(model, 'cpu'): model.cpu() del model except Exception as e: logger.warning(f"Error during model cleanup: {str(e)}") 
================================================ FILE: Backend/src/models/manager.py ================================================ import logging from pathlib import Path from typing import Optional, Tuple, Any, Dict, Union from src.endpoint.models import ModelLoadRequest from src.models.utils.device import get_device from src.models.utils.platform import check_platform_compatibility from src.models.utils.detect_type import detect_model_type from src.models.exceptions import ModelLoadError, ModelNotFoundError from src.models.loaders import ( TransformersLoader, LlamaCppLoader, LlamaCppHFLoader, ExLlamaV2Loader, ExLlamaV2HFLoader, HQQLoader, TensorRTLoader ) logger = logging.getLogger(__name__) class ModelManager: """ Manages the loading, unloading, and switching of different AI models. Supports multiple model types and handles resource management. """ def __init__(self): """Initialize the model manager with empty state.""" self.current_model: Optional[Any] = None self.current_tokenizer: Optional[Any] = None self.model_type: Optional[str] = None self.device: Optional[str] = None self.model_name: Optional[str] = None self._is_loading: bool = False self.model_config: Optional[Dict[str, Any]] = None # Map model types to their respective loaders self.loader_mapping = { 'Transformers': TransformersLoader, 'llama.cpp': LlamaCppLoader, 'llamacpp_HF': LlamaCppHFLoader, 'ExLlamav2': ExLlamaV2Loader, 'ExLlamav2_HF': ExLlamaV2HFLoader, 'HQQ': HQQLoader, 'TensorRT-LLM': TensorRTLoader } def check_platform_compatibility(self, model_type: str) -> Tuple[bool, str]: """Check if the current platform is compatible with the specified model type.""" return check_platform_compatibility(model_type) def get_model_metadata(self, request: ModelLoadRequest) -> Optional[Dict[str, Any]]: """ Get model metadata without loading the full model. Args: request: Model load request containing model information Returns: Dictionary containing model metadata or None if not found """ try: model_path = Path(request.model_path) if request.model_path else Path( f"models/{request.model_name}") # Get the appropriate loader loader_class = self.loader_mapping.get(request.model_type) if loader_class: loader = loader_class(request, self) return loader.get_metadata() return None except Exception as e: logger.error(f"Error getting model metadata: {str(e)}") return None def is_model_loaded(self) -> bool: """Check if a model is currently loaded.""" return self.current_model is not None def get_model_info(self) -> Dict[str, Any]: """ Get information about the currently loaded model. 
Returns: Dictionary containing model information """ info = { "model_name": self.model_name, "model_type": self.model_type, "device": self.device, "is_loaded": self.is_model_loaded(), "is_loading": self._is_loading, } if self.model_config: info["config"] = self._make_json_serializable(self.model_config) return self._make_json_serializable(info) def clear_model(self) -> None: """Unload the current model and clear CUDA cache.""" try: if self.current_model is not None: # Let the specific loader handle cleanup if method exists loader_class = self.loader_mapping.get(self.model_type) if loader_class and hasattr(loader_class, 'cleanup'): loader_class.cleanup(self.current_model) else: # Default cleanup if hasattr(self.current_model, 'cpu'): self.current_model.cpu() del self.current_model if self.current_tokenizer is not None: del self.current_tokenizer # Reset all attributes self.current_model = None self.current_tokenizer = None self.model_type = None self.device = None self.model_name = None self.model_config = None # Clear CUDA cache if available import torch import gc gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception as e: logger.error(f"Error clearing model: {str(e)}") raise def _make_json_serializable(self, obj: Any) -> Any: """Convert objects to JSON serializable format.""" if isinstance(obj, dict): return {k: self._make_json_serializable(v) for k, v in obj.items()} elif isinstance(obj, list): return [self._make_json_serializable(v) for v in obj] elif hasattr(obj, 'dtype'): # Handle torch dtypes return str(obj) return obj def load_model(self, request: ModelLoadRequest) -> Tuple[Any, Any]: """ Load a model based on the request configuration. Args: request: Model load request containing all necessary parameters Returns: Tuple of (model, tokenizer) Raises: ModelLoadError: If there's an error during model loading ModelNotFoundError: If the requested model is not found """ if self._is_loading: raise ModelLoadError("A model is already being loaded") try: self._is_loading = True self.clear_model() # Clear any existing model # Set device using imported get_device function self.device = get_device(request) self.model_name = request.model_name # Handle Ollama models first - convert to llama.cpp if request.model_type == 'ollama': try: # Read the manifest to get the blob SHA manifest_path = Path(request.model_path) / 'latest' logger.info(f"Looking for manifest at: {manifest_path}") if not manifest_path.exists(): raise ModelLoadError(f"Manifest file not found at: {manifest_path}") import json with open(manifest_path) as f: manifest = json.load(f) logger.info(f"Manifest content: {json.dumps(manifest, indent=2)}") # Get the model layer (first layer with mediaType 'application/vnd.ollama.image.model') try: model_layer = next(layer for layer in manifest['layers'] if layer['mediaType'] == 'application/vnd.ollama.image.model') except StopIteration: raise ModelLoadError("No model layer found in manifest") # Extract SHA and construct blob path sha = model_layer['digest'].split(':')[1] # Ollama stores the files directly in the blobs directory with a sha256- prefix blob_path = Path(request.model_path).parent.parent.parent.parent / 'blobs' / f'sha256-{sha}' logger.info(f"Looking for blob at: {blob_path}") if not blob_path.exists(): raise ModelLoadError(f"Model file not found at: {blob_path}") # Update the request to use the actual model file request.model_path = str(blob_path) request.model_type = "llama.cpp" logger.info(f"Converting Ollama model to llama.cpp with path: 
{request.model_path}") except Exception as e: logger.error(f"Error processing Ollama model: {str(e)}") raise ModelLoadError(f"Failed to process Ollama model: {str(e)}") # Check if model exists locally first model_path = Path(request.model_path) if request.model_path else Path(f"models/{request.model_name}") if model_path.exists(): logger.info(f"Found local model at: {model_path}") # Auto-detect model type if not specified if not request.model_type or request.model_type == "auto": request.model_type = self._detect_model_type(request) logger.info(f"Detected model type: {request.model_type}") else: # Only attempt to download if it looks like a HF model ID if '/' in request.model_name: logger.info(f"Model not found locally, will attempt to download from HuggingFace") else: raise ModelNotFoundError(f"Model not found at: {model_path}") # Check platform compatibility is_compatible, message = check_platform_compatibility(request.model_type) if not is_compatible: raise ModelLoadError(message) logger.info(message) # Get the appropriate loader loader_class = self.loader_mapping.get(request.model_type) logger.info(f"Model type: {request.model_type}") logger.info(f"Available loaders: {list(self.loader_mapping.keys())}") if not loader_class: raise ModelLoadError(f"Unsupported model type: {request.model_type}") # Initialize and use the loader loader = loader_class(request, self) model, tokenizer = loader.load() # Store the results self.current_model = model self.current_tokenizer = tokenizer self.model_type = request.model_type # Make config JSON serializable before storing self.model_config = self._make_json_serializable(loader.get_config()) return model, tokenizer except Exception as e: logger.error(f"Error loading model: {str(e)}", exc_info=True) self.clear_model() # Cleanup on failure if isinstance(e, (ModelLoadError, ModelNotFoundError)): raise raise ModelLoadError(str(e)) finally: self._is_loading = False def _detect_model_type(self, request: ModelLoadRequest) -> str: """ Detect the type of model based on the model path and name. 
Args: request: Model load request Returns: String indicating the detected model type """ model_path = Path(request.model_path) if request.model_path else Path( f"models/{request.model_name}") if model_path.exists(): return detect_model_type(model_path) # Default to Transformers for HF models if '/' in request.model_name: return "Transformers" raise ModelNotFoundError( f"Could not detect model type: {request.model_name}") # Global model manager instance model_manager = ModelManager() ================================================ FILE: Backend/src/models/streamer.py ================================================ import traceback from queue import Queue from threading import Thread from typing import Optional, Callable, Any, List, Union, AsyncIterator, Iterator, Dict import torch import time import asyncio import json import logging logger = logging.getLogger(__name__) class StopNowException(Exception): pass class StreamingStoppingCriteria: """Base class for stopping criteria during text generation""" def __init__(self): pass def __call__(self, input_ids, scores) -> bool: return False class StopOnInterrupt(StreamingStoppingCriteria): """Stopping criteria that checks for interruption signals""" def __init__(self, stop_signal=None): super().__init__() self.stop_signal = stop_signal or (lambda: False) def __call__(self, input_ids, scores) -> bool: return self.stop_signal() class StreamIterator(AsyncIterator[str], Iterator[str]): """Iterator that streams tokens as they are generated.""" def __init__(self, func: Callable, callback: Optional[Callable] = None): self.func = func self.callback = callback self.queue = Queue() self.async_queue = asyncio.Queue() self.sentinel = object() self.stop_now = False self.thread = None def _queue_callback(self, data): """Callback that puts data into both queues""" if self.stop_now: raise StopNowException if data is None: self.queue.put(self.sentinel) self.async_queue.put_nowait(None) return if self.callback: self.callback(data) formatted_data = f"data: {json.dumps(data)}\n\n" self.queue.put(formatted_data) self.async_queue.put_nowait(formatted_data) def _start_generation(self): if not self.thread: def task(): try: self.func(self._queue_callback) except StopNowException: pass except Exception: traceback.print_exc() finally: self._queue_callback(None) self.thread = Thread(target=task) self.thread.start() def __iter__(self) -> Iterator[str]: self._start_generation() return self def __next__(self) -> str: if not self.thread: self._start_generation() item = self.queue.get() if item is self.sentinel: raise StopIteration return item def __aiter__(self): self._start_generation() return self async def __anext__(self) -> str: if not self.thread: self._start_generation() try: item = await self.async_queue.get() if item is None: raise StopAsyncIteration return item except Exception as e: if isinstance(e, StopAsyncIteration): raise raise StopAsyncIteration from e def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.stop_now = True class TextGenerator: """A text generator that streams tokens as they are generated.""" def __init__(self, model, tokenizer, device: str = "cpu"): self.model = model self.tokenizer = tokenizer self.device = device self.stop_signal = False self._log_cuda_status() def _log_cuda_status(self): """Log CUDA status if available""" if hasattr(torch.cuda, 'is_available') and torch.cuda.is_available(): logger.info("CUDA is available in TextGenerator") logger.info( f"Model GPU layers: {getattr(self.model, 'n_gpu_layers', 
'unknown')}") logger.info( f"CUDA Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f}MB") logger.info( f"CUDA Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f}MB") def _create_stream_response(self, text: str, generated_text: str, is_final: bool = False) -> Dict: """Create a standardized streaming response""" response = { "id": "chatcmpl-" + str(hash(generated_text))[-12:], "object": "chat.completion.chunk", "created": int(time.time()), "model": "local-model", "choices": [{ "index": 0, "delta": {} if is_final else {"content": text}, "finish_reason": "stop" if is_final else None }] } return response def _stream_tokens(self, callback: Callable, generator, decode_func: Callable) -> str: """Generic token streaming implementation""" generated_text = "" for output in generator: text = decode_func(output) generated_text += text callback(self._create_stream_response(text, generated_text)) # Send final message callback(self._create_stream_response( "", generated_text, is_final=True)) callback(None) return generated_text def generate(self, prompt: str, max_new_tokens: int = 100, temperature: float = 0.7, top_p: float = 0.95, top_k: int = 50, repetition_penalty: float = 1.1, stopping_criteria: Optional[List[StreamingStoppingCriteria]] = None, callback: Optional[Callable[[dict], Any]] = None, stream: bool = True) -> Union[str, Any]: """Generate text from a prompt, optionally streaming the output.""" if hasattr(self.model, 'create_completion'): # llama.cpp model completion_args = { "prompt": prompt, "max_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repeat_penalty": repetition_penalty, "stream": stream } if stream: def _stream(callback): completion = self.model.create_completion( **completion_args) return self._stream_tokens( callback, completion, lambda x: x["choices"][0]["text"] ) return StreamIterator(_stream, callback=callback) else: completion = self.model.create_completion(**completion_args) return completion["choices"][0]["text"] else: # Other models (transformers) inputs = self.tokenizer( prompt, return_tensors="pt", padding=True).to(self.device) gen_config = { "max_new_tokens": max_new_tokens, "temperature": max(temperature, 1e-2), "top_p": min(max(top_p, 0.1), 0.95), "top_k": top_k, "repetition_penalty": repetition_penalty, "do_sample": True, "pad_token_id": self.tokenizer.pad_token_id, "eos_token_id": self.tokenizer.eos_token_id, "use_cache": True } if stream: def _stream(callback): with torch.no_grad(): generator = self.model.generate( **inputs, **gen_config, stopping_criteria=stopping_criteria, return_dict_in_generate=True, output_scores=True ) return self._stream_tokens( callback, generator, lambda x: self.tokenizer.decode( [x.sequences[0, -1].item() if not isinstance(x, torch.Tensor) else x.item()], skip_special_tokens=True ) ) return StreamIterator(_stream, callback=callback) else: with torch.no_grad(): output = self.model.generate( **inputs, **gen_config, stopping_criteria=stopping_criteria, return_dict_in_generate=True, output_scores=True ) return self.tokenizer.decode(output.sequences[0], skip_special_tokens=True) # End of TextGenerator class - everything after this line should be removed ================================================ FILE: Backend/src/models/utils/__init__.py ================================================ ================================================ FILE: Backend/src/models/utils/detect_type.py ================================================ import json from pathlib import Path from typing 
================================================
FILE: Backend/src/models/utils/__init__.py
================================================

================================================
FILE: Backend/src/models/utils/detect_type.py
================================================
import json
import logging
from pathlib import Path
from typing import Union

logger = logging.getLogger(__name__)


def detect_model_type(model_path: Union[str, Path]) -> str:
    """
    Detect the model type from the model files and metadata
    Returns one of: 'ollama', 'Transformers', 'llama.cpp', 'llamacpp_HF',
    'ExLlamav2', 'ExLlamav2_HF', 'HQQ', 'TensorRT-LLM'
    """
    model_path = Path(model_path)
    if not model_path.exists():
        raise ValueError(f"Model path does not exist: {model_path}")

    # Check for model metadata
    metadata_path = model_path / "metadata.json"
    if metadata_path.exists():
        try:
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
                if "model_type" in metadata:
                    return metadata["model_type"]
        except Exception:
            logger.warning(f"Could not read metadata from {metadata_path}")

    # Check for specific file patterns
    files = list(model_path.glob("*"))
    file_names = [f.name for f in files]

    # TensorRT-LLM check
    if any(f.endswith('.engine') for f in file_names) or any(f.endswith('.plan') for f in file_names):
        return 'TensorRT-LLM'

    # llama.cpp check: GGUF files are handled by the llama.cpp loaders
    if any(f.endswith('.gguf') for f in file_names):
        # Check if there's a HF tokenizer
        if any(f == 'tokenizer_config.json' for f in file_names):
            return 'llamacpp_HF'
        return 'llama.cpp'

    # HQQ check
    if any(f.endswith('.hqq') for f in file_names):
        return 'HQQ'

    # Default to Transformers for standard HF models
    if any(f in file_names for f in ['config.json', 'pytorch_model.bin', 'model.safetensors']):
        # Only check for ExLlamav2 if we find specific ExLlamav2 files
        if (model_path / 'tokenizer.model').exists():
            config_path = model_path / 'config.json'
            try:
                with open(config_path, 'r') as f:
                    config = json.load(f)
                    if config.get('model_type', '').lower() in ['llama', 'mistral']:
                        return 'ExLlamav2'
            except Exception:
                pass
        return 'Transformers'

    raise ValueError(
        f"Could not determine model type from files in {model_path}")

================================================
FILE: Backend/src/models/utils/device.py
================================================
import torch
from src.endpoint.models import ModelLoadRequest


def get_device(request: ModelLoadRequest) -> str:
    if request.device != "auto":
        return request.device
    if torch.cuda.is_available():
        print("CUDA is available")
        return "cuda"
    elif torch.backends.mps.is_available():
        print("MPS is available")
        return "mps"
    else:
        print("No GPU available")
        return "cpu"

================================================
FILE: Backend/src/models/utils/download.py
================================================
import os
import logging
import requests
from tqdm import tqdm
from pathlib import Path
from typing import List, Dict, Optional

logger = logging.getLogger(__name__)


def download_file_with_progress(url: str, file_path: Path, headers: Optional[Dict[str, str]] = None) -> None:
    """Download a file with progress bar"""
    try:
        response = requests.get(url, stream=True, headers=headers or {})
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        block_size = 8192  # 8 KB
        with open(file_path, 'wb') as f, tqdm(
            desc=file_path.name,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(block_size):
                size = f.write(data)
                pbar.update(size)
        logger.info(f"Successfully downloaded {file_path.name}")
    except Exception as e:
        if file_path.exists() and file_path.stat().st_size == 0:
            file_path.unlink()  # Remove empty/partial file
        raise ValueError(f"Failed to download file {file_path.name}: {str(e)}")
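
# Usage sketch (illustrative, not part of the original file): fetching a
# single file with the helper above. The repository and destination directory
# are assumptions; config.json exists in virtually every HF model repo.
if __name__ == "__main__":
    dest_dir = Path("models/tiny-gpt2")
    dest_dir.mkdir(parents=True, exist_ok=True)
    download_file_with_progress(
        "https://huggingface.co/sshleifer/tiny-gpt2/resolve/main/config.json",
        dest_dir / "config.json",
    )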
= f"https://huggingface.co/api/models/{repo_id}/tree/main" headers = {"Accept": "application/json"} if hf_token: headers["Authorization"] = f"Bearer {hf_token}" logger.info("Using provided HuggingFace token") logger.info(f"Fetching repository contents from {api_url}") response = requests.get(api_url, headers=headers) response.raise_for_status() return response.json() def download_hf_model_files(repo_id: str, model_path: Path, required_files: List[str], hf_token: Optional[str] = None) -> None: """Download required files from a HuggingFace repository""" try: files = get_hf_repo_files(repo_id, hf_token) logger.info(f"Found {len(files)} files in repository") logger.info(f"Required files: {required_files}") headers = {} if hf_token: headers["Authorization"] = f"Bearer {hf_token}" for file_name in required_files: file_info = next((f for f in files if f['path'] == file_name), None) if not file_info: logger.error(f"Required file {file_name} not found in repository. Available files: {[f['path'] for f in files]}") raise ValueError(f"Required file {file_name} not found in repository {repo_id}") download_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}" file_path = model_path / file_name logger.info(f"Downloading {file_name} ({file_info.get('size', 'unknown size')}) from {download_url}") download_file_with_progress(download_url, file_path, headers) except Exception as e: logger.error(f"Failed to download model: {str(e)}", exc_info=True) # Clean up any partially downloaded files if model_path.exists(): import shutil shutil.rmtree(model_path) raise ValueError(f"Failed to download model: {str(e)}") def find_best_gguf_file(files: List[Dict]) -> Optional[Dict]: """Find the best GGUF file from a list of files, preferring q4_k_m files and sorting by size""" gguf_files = [f for f in files if f.get('path', '').endswith('.gguf')] if not gguf_files: return None # Sort by preference for q4_k_m files and then by size gguf_files.sort(key=lambda x: ( 0 if 'q4_k_m' in x['path'].lower() else 1, x.get('size', float('inf')) )) return gguf_files[0] def download_gguf_model(repo_id: str, model_path: Path, hf_token: Optional[str] = None) -> Path: """Download a GGUF model from HuggingFace""" try: files = get_hf_repo_files(repo_id, hf_token) file_info = find_best_gguf_file(files) if not file_info: raise ValueError(f"No GGUF files found in repository {repo_id}") file_name = file_info['path'] download_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}" model_path = model_path / file_name # Only download if file doesn't exist or is empty if not model_path.exists() or model_path.stat().st_size == 0: headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {} download_file_with_progress(download_url, model_path, headers) return model_path except Exception as e: if model_path.exists() and model_path.stat().st_size == 0: model_path.unlink() raise ValueError(f"Failed to download GGUF model: {str(e)}") ================================================ FILE: Backend/src/models/utils/platform.py ================================================ import platform from typing import Tuple def check_platform_compatibility(model_type: str) -> Tuple[bool, str]: """ Check if the model type is compatible with the current platform Returns (is_compatible, message) """ current_platform = platform.system().lower() platform_compatibility = { 'TensorRT-LLM': ['linux'], # TensorRT only works on Linux # ExLlama works on Windows and Linux 'ExLlamav2': ['windows', 'linux'], 'ExLlamav2_HF': ['windows', 'linux'], # 
HQQ works on all platforms 'HQQ': ['linux', 'windows', 'darwin'], # llama.cpp works on all platforms 'llama.cpp': ['linux', 'windows', 'darwin'], 'llamacpp_HF': ['linux', 'windows', 'darwin'], # Transformers works on all platforms 'Transformers': ['linux', 'windows', 'darwin'], 'ollama': ['linux', 'windows', 'darwin'] } compatible_platforms = platform_compatibility.get(model_type, []) is_compatible = current_platform in compatible_platforms if not is_compatible: message = f"Model type '{model_type}' is not compatible with {platform.system()}. Compatible platforms: {', '.join(compatible_platforms)}" else: message = f"Model type '{model_type}' is compatible with {platform.system()}" return is_compatible, message ================================================ FILE: Backend/src/vectorstorage/embeddings.py ================================================ import time def chunk_list(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): yield lst[i:i + n] def embed_chunk(args): """Embed a chunk of documents.""" vectordb, chunk, chunk_num, total_chunks, start_time, time_history = args try: vectordb.add_documents(chunk) # Calculate time taken for this chunk current_time = time.time() chunk_time = current_time - start_time time_history.append(chunk_time) # Keep only last 5 times if len(time_history) > 5: time_history.popleft() # Basic stats to return for all chunks result = { "chunk": chunk_num, "total_chunks": total_chunks, "docs_in_chunk": len(chunk), "percent_complete": round((chunk_num / total_chunks * 100), 2), "elapsed_time": current_time - start_time, } # Only add time estimates after 20 chunks and if we have enough data points if chunk_num >= 20 and len(time_history) >= 3: current_avg_time = sum(time_history) / len(time_history) # Store the lowest average time seen so far if not hasattr(embed_chunk, 'lowest_avg_time') or current_avg_time < embed_chunk.lowest_avg_time: embed_chunk.lowest_avg_time = current_avg_time remaining_chunks = total_chunks - chunk_num est_remaining_time = remaining_chunks * embed_chunk.lowest_avg_time est_finish_time = time.strftime( '%H:%M:%S', time.localtime(current_time + est_remaining_time)) est_remaining_time_formatted = time.strftime( '%H:%M:%S', time.gmtime(est_remaining_time)) result.update({ "est_finish_time": est_finish_time, "time_per_chunk": embed_chunk.lowest_avg_time, "remaining_chunks": remaining_chunks, "est_remaining_time": est_remaining_time_formatted }) else: result.update({ "est_finish_time": "calculating...", "time_per_chunk": "calculating...", "remaining_chunks": total_chunks - chunk_num, "est_remaining_time": "calculating..." 
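
# Usage sketch (illustrative, not part of the original file): gate a loader
# on the current OS before attempting a load.
if __name__ == "__main__":
    ok, message = check_platform_compatibility("TensorRT-LLM")
    print(message)  # on macOS this reports incompatibility (Linux only)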
================================================
FILE: Backend/src/vectorstorage/embeddings.py
================================================
import time


def chunk_list(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def embed_chunk(args):
    """Embed a chunk of documents."""
    # time_history is expected to be a collections.deque (popleft() is used below)
    vectordb, chunk, chunk_num, total_chunks, start_time, time_history = args
    try:
        vectordb.add_documents(chunk)
        # Calculate time taken for this chunk
        current_time = time.time()
        chunk_time = current_time - start_time
        time_history.append(chunk_time)
        # Keep only last 5 times
        if len(time_history) > 5:
            time_history.popleft()
        # Basic stats to return for all chunks
        result = {
            "chunk": chunk_num,
            "total_chunks": total_chunks,
            "docs_in_chunk": len(chunk),
            "percent_complete": round((chunk_num / total_chunks * 100), 2),
            "elapsed_time": current_time - start_time,
        }
        # Only add time estimates after 20 chunks and if we have enough data points
        if chunk_num >= 20 and len(time_history) >= 3:
            current_avg_time = sum(time_history) / len(time_history)
            # Store the lowest average time seen so far
            if not hasattr(embed_chunk, 'lowest_avg_time') or current_avg_time < embed_chunk.lowest_avg_time:
                embed_chunk.lowest_avg_time = current_avg_time
            remaining_chunks = total_chunks - chunk_num
            est_remaining_time = remaining_chunks * embed_chunk.lowest_avg_time
            est_finish_time = time.strftime(
                '%H:%M:%S', time.localtime(current_time + est_remaining_time))
            est_remaining_time_formatted = time.strftime(
                '%H:%M:%S', time.gmtime(est_remaining_time))
            result.update({
                "est_finish_time": est_finish_time,
                "time_per_chunk": embed_chunk.lowest_avg_time,
                "remaining_chunks": remaining_chunks,
                "est_remaining_time": est_remaining_time_formatted
            })
        else:
            result.update({
                "est_finish_time": "calculating...",
                "time_per_chunk": "calculating...",
                "remaining_chunks": total_chunks - chunk_num,
                "est_remaining_time": "calculating..."
            })
        return result
    except Exception as e:
        raise Exception(
            f"Error embedding chunk {chunk_num}/{total_chunks}: {str(e)}")

================================================
FILE: Backend/src/vectorstorage/helpers/sanitizeCollectionName.py
================================================
import re


def sanitize_collection_name(name):
    try:
        sanitized = re.sub(r'[^\w\-]', '_', name)
        sanitized = re.sub(r'^[^\w]|[^\w]$', '', sanitized)
        sanitized = re.sub(r'\.{2,}', '_', sanitized)
        if len(sanitized) < 3:
            sanitized = sanitized.ljust(3, "_")
        elif len(sanitized) > 63:
            sanitized = sanitized[:63]
        return sanitized
    except Exception as e:
        print(f"Error sanitizing collection name: {str(e)}")
        return None

================================================
FILE: Backend/src/vectorstorage/init_store.py
================================================
from langchain_huggingface import HuggingFaceEmbeddings
import logging
import torch
import os
from pathlib import Path

logger = logging.getLogger(__name__)


def get_models_dir():
    if os.name == 'posix':
        # For Linux, use ~/.local/share/Notate/models
        if os.uname().sysname == 'Linux':
            base_dir = os.path.expanduser('~/.local/share/Notate')
        # For macOS, use ~/Library/Application Support/Notate/models
        else:
            base_dir = os.path.expanduser('~/Library/Application Support/Notate')
    else:
        # For Windows, use ~/.notate
        base_dir = os.path.expanduser('~/.notate')
    models_dir = os.path.join(base_dir, 'embeddings_models')
    os.makedirs(models_dir, exist_ok=True)
    return models_dir


async def init_store(model_name: str = "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5"):
    logger.info("Initializing HuggingFace embeddings")
    # Determine the appropriate device
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    logger.info(f"Using device: {device}")
    models_dir = get_models_dir()
    logger.info(f"Using models directory: {models_dir}")
    model_kwargs = {
        "device": device
    }
    encode_kwargs = {
        "device": device,
        "normalize_embeddings": True,
        "max_seq_length": 512
    }
    try:
        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
            cache_folder=models_dir
        )
        return embeddings
    except Exception as e:
        logger.error(f"Error initializing embeddings: {str(e)}")
        # Fallback to CPU if there's an error with the device
        if device != "cpu":
            logger.info("Falling back to CPU")
            model_kwargs["device"] = "cpu"
            encode_kwargs["device"] = "cpu"
            return HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs,
                cache_folder=models_dir
            )
        raise

================================================
FILE: Backend/src/vectorstorage/vectorstore.py
================================================
from src.vectorstorage.init_store import get_models_dir
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
import torch
import os
import logging
import platform

logger = logging.getLogger(__name__)


def get_app_data_dir():
    home_dir = os.path.expanduser("~")
    if platform.system() == "Darwin":  # macOS
        app_data_dir = os.path.join(home_dir, "Library/Application Support/Notate")
    elif platform.system() == "Linux":  # Linux
        app_data_dir = os.path.join(home_dir, ".local/share/Notate")
    else:  # Windows and others
        app_data_dir = os.path.join(home_dir, ".notate")
    os.makedirs(app_data_dir, exist_ok=True)
    return app_data_dir


chroma_db_path = os.path.join(get_app_data_dir(),
"chroma_db") logger.info(f"Using Chroma DB path: {chroma_db_path}") def get_vectorstore(api_key: str, collection_name: str, use_local_embeddings: bool = False, local_embedding_model: str = "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5"): try: # Get embeddings if use_local_embeddings or api_key is None: logger.info(f"Using local embedding model: {local_embedding_model}") # Determine the appropriate device if torch.cuda.is_available(): device = "cuda" elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = "mps" else: device = "cpu" logger.info(f"Using device: {device}") models_dir = get_models_dir() logger.info(f"Using models directory: {models_dir}") model_kwargs = {"device": device} encode_kwargs = { "device": device, "normalize_embeddings": True, "max_seq_length": 512 } try: embeddings = HuggingFaceEmbeddings( model_name=local_embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs, cache_folder=models_dir ) except Exception as e: logger.error(f"Error initializing embeddings with {device}: {str(e)}") if device != "cpu": logger.info("Falling back to CPU") model_kwargs["device"] = "cpu" encode_kwargs["device"] = "cpu" embeddings = HuggingFaceEmbeddings( model_name=local_embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs, cache_folder=models_dir ) else: raise else: logger.info("Using OpenAI embedding model") embeddings = OpenAIEmbeddings(api_key=api_key) # Try to create vectorstore with specific settings try: from chromadb.config import Settings import chromadb # Use in-memory store if persistent store fails try: chroma_client = chromadb.PersistentClient( path=chroma_db_path, settings=Settings( anonymized_telemetry=False, allow_reset=True, is_persistent=True ) ) except Exception as e: logger.warning(f"Failed to create persistent client: {str(e)}, falling back to in-memory") chroma_client = chromadb.Client( settings=Settings( anonymized_telemetry=False, allow_reset=True, is_persistent=False ) ) vectorstore = Chroma( client=chroma_client, embedding_function=embeddings, collection_name=collection_name, ) logger.info(f"Successfully initialized vectorstore for collection: {collection_name}") return vectorstore except Exception as e: logger.error(f"Error creating Chroma instance: {str(e)}") # Try one more time with in-memory store try: chroma_client = chromadb.Client( settings=Settings( anonymized_telemetry=False, allow_reset=True, is_persistent=False ) ) vectorstore = Chroma( client=chroma_client, embedding_function=embeddings, collection_name=collection_name, ) return vectorstore except Exception as e2: logger.error(f"Error creating in-memory Chroma instance: {str(e2)}") return None except Exception as e: logger.error(f"Error getting vectorstore: {str(e)}") return None ================================================ FILE: Backend/src/voice/voice_to_text.py ================================================ import whisper import os import warnings import torch import shutil import subprocess # Suppress specific warnings warnings.filterwarnings( "ignore", message=".*weights_only=False.*", category=FutureWarning) warnings.filterwarnings( "ignore", message="FP16 is not supported on CPU; using FP32 instead") warnings.filterwarnings("ignore", category=FutureWarning, module="torch.serialization") # Global variables model = None ffmpeg_path = None def initialize_model(model_name: str = "base"): """Initialize the Whisper model with optimal device and precision settings.""" global model, ffmpeg_path if model is None: # Get FFmpeg 
path from environment variable ffmpeg_path = os.environ.get('FFMPEG_PATH') if ffmpeg_path: try: # Verify FFmpeg works subprocess.run([ffmpeg_path, "-version"], capture_output=True, check=True) print(f"FFmpeg verified at: {ffmpeg_path}") # Set environment variables for Whisper os.environ["PATH"] = os.pathsep.join( [os.path.dirname(ffmpeg_path), os.environ.get('PATH', '')]) os.environ["FFMPEG_BINARY"] = ffmpeg_path except Exception as e: print(f"Warning: Error verifying FFmpeg at {ffmpeg_path}: {e}") ffmpeg_path = None if not ffmpeg_path: # Try to find system FFmpeg ffmpeg_system = shutil.which('ffmpeg') if ffmpeg_system: ffmpeg_path = ffmpeg_system os.environ["FFMPEG_BINARY"] = ffmpeg_path print(f"Using system FFmpeg from: {ffmpeg_path}") else: print("FFmpeg not found or not working") return None # Initialize Whisper model device = "cuda" if torch.cuda.is_available() else "cpu" fp16 = device == "cuda" print(f"Loading Whisper model '{model_name}' on {device}...") model = whisper.load_model(model_name) model.to(device) if device == "cuda" and fp16: model = model.half() print(f"Using GPU with FP16={fp16}") else: print("Using CPU with FP32") return model ================================================ FILE: Backend/tests/testApi.py ================================================ import pytest from fastapi.testclient import TestClient from main import app from src.endpoint.models import EmbeddingRequest, QueryRequest, YoutubeTranscriptRequest client = TestClient(app) def test_embed_endpoint(): # Test successful embedding data = EmbeddingRequest( file_path="test_file.txt", api_key="test_api_key", collection=1, collection_name="test_collection", user=1, metadata={"title": "Test Document"} ) response = client.post("/embed", json=data.dict()) assert response.status_code == 200 assert "text/event-stream" in response.headers["content-type"] def test_concurrent_embedding(): # Test that only one embedding process can run at a time data = EmbeddingRequest( file_path="test_file.txt", api_key="test_api_key", collection=1, collection_name="test_collection", user=1, metadata={"title": "Test Document"} ) # Start first embedding response1 = client.post("/embed", json=data.dict()) assert response1.status_code == 200 # Try to start second embedding response2 = client.post("/embed", json=data.dict()) assert response2.status_code == 200 response_data = response2.json() assert response_data["status"] == "error" assert response_data["message"] == "An embedding process is already running" def test_youtube_ingest(): data = YoutubeTranscriptRequest( url="https://www.youtube.com/watch?v=test_id", user_id=1, collection_id=1, username="test_user", collection_name="test_collection", api_key="test_api_key" ) response = client.post("/youtube-ingest", json=data.dict()) assert response.status_code == 200 assert "text/event-stream" in response.headers["content-type"] def test_cancel_embedding(): # Test cancelling when no embedding is running response = client.post("/cancel-embed") assert response.status_code == 200 response_data = response.json() assert response_data["status"] == "error" assert response_data["message"] == "No embedding process running" # Start an embedding process embed_data = EmbeddingRequest( file_path="test_file.txt", api_key="test_api_key", collection=1, collection_name="test_collection", user=1, metadata={"title": "Test Document"} ) embed_response = client.post("/embed", json=embed_data.dict()) assert embed_response.status_code == 200 # Cancel the embedding process cancel_response = 
client.post("/cancel-embed") assert cancel_response.status_code == 200 cancel_data = cancel_response.json() assert cancel_data["status"] == "success" assert cancel_data["message"] == "Embedding process cancelled" def test_query(): data = QueryRequest( query="test query", collection=1, collection_name="test_collection", user=1, api_key="test_api_key", top_k=5 ) response = client.post("/vector-query", json=data.dict()) assert response.status_code == 200 # Test error handling invalid_data = QueryRequest( query="", # Empty query should raise an error collection=1, collection_name="test_collection", user=1, api_key="test_api_key", top_k=5 ) response = client.post("/vector-query", json=invalid_data.dict()) assert response.status_code == 200 # FastAPI still returns 200 but with error message response_data = response.json() assert response_data["status"] == "error" # Note: We don't test the restart-server endpoint directly as it would terminate our test process ================================================ FILE: Backend/tests/test_voice.py ================================================ import pytest from fastapi.testclient import TestClient from main import app import os import tempfile import wave import numpy as np import sounddevice as sd client = TestClient(app) def create_test_wav(duration=3.0, frequency=440.0, sample_rate=16000): """Create a test WAV file with a sine wave.""" # Generate time array t = np.linspace(0, duration, int(sample_rate * duration), False) # Generate sine wave note = np.sin(2 * np.pi * frequency * t) # Normalize to 16-bit range and convert to integers audio = note * 32767 audio = audio.astype(np.int16) # Create a temporary file temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav') # Write WAV file with wave.open(temp_file.name, 'wb') as wav_file: wav_file.setnchannels(1) # Mono wav_file.setsampwidth(2) # 2 bytes per sample (16-bit) wav_file.setframerate(sample_rate) wav_file.writeframes(audio.tobytes()) return temp_file.name def test_voice_to_text_basic(): """Test basic voice-to-text functionality with a generated WAV file.""" # Create a test WAV file test_file = create_test_wav() try: with open(test_file, 'rb') as f: files = {'audio_file': ('test.wav', f, 'audio/wav')} response = client.post("/voice-to-text", files=files) assert response.status_code == 200 result = response.json() assert "status" in result assert "text" in result assert "language" in result assert "segments" in result finally: # Clean up the test file os.unlink(test_file) def test_voice_to_text_models(): """Test voice-to-text with different Whisper models.""" test_file = create_test_wav() try: models = ['tiny', 'base', 'small'] # We'll test with smaller models for speed for model in models: with open(test_file, 'rb') as f: files = {'audio_file': ('test.wav', f, 'audio/wav')} response = client.post("/voice-to-text", files=files, data={'model_name': model}) assert response.status_code == 200 result = response.json() assert result["status"] == "success" finally: os.unlink(test_file) def test_voice_to_text_invalid_audio(): """Test voice-to-text with invalid audio data.""" # Create an invalid audio file with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file: temp_file.write(b'This is not valid audio data') try: with open(temp_file.name, 'rb') as f: files = {'audio_file': ('invalid.wav', f, 'audio/wav')} response = client.post("/voice-to-text", files=files) assert response.status_code == 200 result = response.json() assert result["status"] == "error" assert "error" in 

def test_voice_to_text_models():
    """Test voice-to-text with different Whisper models."""
    test_file = create_test_wav()
    try:
        models = ['tiny', 'base', 'small']  # We'll test with smaller models for speed
        for model in models:
            with open(test_file, 'rb') as f:
                files = {'audio_file': ('test.wav', f, 'audio/wav')}
                response = client.post("/voice-to-text", files=files, data={'model_name': model})
            assert response.status_code == 200
            result = response.json()
            assert result["status"] == "success"
    finally:
        os.unlink(test_file)


def test_voice_to_text_invalid_audio():
    """Test voice-to-text with invalid audio data."""
    # Create an invalid audio file
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
        temp_file.write(b'This is not valid audio data')
    try:
        with open(temp_file.name, 'rb') as f:
            files = {'audio_file': ('invalid.wav', f, 'audio/wav')}
            response = client.post("/voice-to-text", files=files)
        assert response.status_code == 200
        result = response.json()
        assert result["status"] == "error"
        assert "error" in result
    finally:
        os.unlink(temp_file.name)


def test_voice_to_text_missing_file():
    """Test voice-to-text without providing an audio file."""
    response = client.post("/voice-to-text")
    assert response.status_code == 422  # FastAPI validation error


def test_voice_to_text_long_audio():
    """Test voice-to-text with a longer audio file."""
    test_file = create_test_wav(duration=10.0)  # 10 seconds
    try:
        with open(test_file, 'rb') as f:
            files = {'audio_file': ('long.wav', f, 'audio/wav')}
            response = client.post("/voice-to-text", files=files)
        assert response.status_code == 200
        result = response.json()
        assert result["status"] == "success"
        assert "text" in result
        assert "language" in result
        assert "segments" in result
    finally:
        os.unlink(test_file)


def test_voice_to_text_different_frequencies():
    """Test voice-to-text with different audio frequencies."""
    frequencies = [440.0, 880.0, 1760.0]  # A4, A5, A6 notes
    for freq in frequencies:
        test_file = create_test_wav(frequency=freq)
        try:
            with open(test_file, 'rb') as f:
                files = {'audio_file': (f'freq_{freq}.wav', f, 'audio/wav')}
                response = client.post("/voice-to-text", files=files)
            assert response.status_code == 200
            result = response.json()
            assert result["status"] == "success"
        finally:
            os.unlink(test_file)


def record_audio(duration=5, sample_rate=16000):
    """Record audio from the microphone."""
    print(f"Recording for {duration} seconds...")
    audio_data = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=1, dtype=np.int16)
    sd.wait()  # Wait until recording is finished
    return audio_data


def test_live_voice_to_text(capsys):
    """Test voice-to-text with live microphone input."""
    # Record audio
    sample_rate = 16000
    duration = 5  # 5 seconds of recording
    with capsys.disabled():
        print("\n=== Live Voice-to-Text Test ===")
        print("Please speak into your microphone...")
        audio_data = record_audio(duration, sample_rate)
    # Create a temporary WAV file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
    try:
        # Save the recorded audio to WAV file
        with wave.open(temp_file.name, 'wb') as wav_file:
            wav_file.setnchannels(1)  # Mono
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_data.tobytes())
        # Send the recorded audio for transcription
        with open(temp_file.name, 'rb') as f:
            files = {'audio_file': ('recording.wav', f, 'audio/wav')}
            response = client.post("/voice-to-text", files=files)
        assert response.status_code == 200
        result = response.json()
        assert result["status"] == "success"
        assert "text" in result
        with capsys.disabled():
            print(f"\nTranscribed text: {result['text']}")
            print("================================\n")
    finally:
        os.unlink(temp_file.name)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])


================================================
FILE: Frontend/.gitignore
================================================
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-react
dist-ssr
dist-electron
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
/test-results/
/playwright-report/
/blob-report/
/playwright/.cache/


================================================
FILE: Frontend/components.json
================================================
{
  "$schema": "https://ui.shadcn.com/schema.json",
  "style": "new-york",
  "rsc": false,
  "tsx": true,
  "tailwind": {
    "config": "tailwind.config.js",
    "css": "src/app/index.css",
    "baseColor": "neutral",
    "cssVariables": true,
    "prefix": ""
  },
  "aliases": {
    "components": "@/components",
    "utils": "@/lib/utils",
    "ui": "@/components/ui",
    "lib": "@/lib",
    "hooks": "@/hooks"
  },
  "iconLibrary": "lucide"
}


================================================
FILE: Frontend/e2e/app.spec.ts
================================================
import {
  test,
  expect,
  _electron,
  Page,
  ElectronApplication,
} from "@playwright/test";

let electronApp: ElectronApplication;
let loadingWindow: Page;
let mainWindow: Page;

// Increase timeout for the entire test file
test.setTimeout(160000);

// Note: the return-type annotations below were lost to tag-stripping in this
// extract; Promise<Page> / Promise<unknown> are restored from how the helpers
// are used further down.
async function waitForMainWindow(timeout = 45000): Promise<Page> {
  const startTime = Date.now();
  while (Date.now() - startTime < timeout) {
    const windows = await electronApp.windows();
    // Find the window that's not the loading window
    const mainWin = windows.find((win) => win !== loadingWindow);
    if (mainWin) {
      return mainWin;
    }
    await new Promise((resolve) => setTimeout(resolve, 100));
  }
  throw new Error("Main window did not appear within timeout");
}

async function waitForPreloadScript(page: Page): Promise<unknown> {
  const timeout = 30000;
  const startTime = Date.now();
  return new Promise((resolve, reject) => {
    const interval = setInterval(async () => {
      try {
        if (Date.now() - startTime > timeout) {
          clearInterval(interval);
          reject(new Error("Timeout waiting for preload script"));
          return;
        }
        const electronBridge = await page.evaluate(() => {
          return (window as { electron?: unknown }).electron;
        });
        if (electronBridge) {
          clearInterval(interval);
          resolve(electronBridge);
        }
      } catch (error) {
        clearInterval(interval);
        reject(error);
      }
    }, 100);
  });
}
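
// Both helpers above poll on a fixed interval; if more of these accumulate, one
// possible generalization (an editor's sketch, not code from this repo) is a
// single predicate-based wait:
//
//   async function waitFor<T>(
//     probe: () => Promise<T | undefined>, // returns a value once the condition holds
//     timeout = 30000,
//     interval = 100
//   ): Promise<T> {
//     const start = Date.now();
//     while (Date.now() - start < timeout) {
//       const value = await probe();
//       if (value !== undefined) return value;
//       await new Promise((r) => setTimeout(r, interval));
//     }
//     throw new Error("Timed out waiting for condition");
//   }
//
// waitForMainWindow could then be expressed as
//   waitFor(async () => (await electronApp.windows()).find((w) => w !== loadingWindow), 45000)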

test.beforeEach(async () => {
  // Launch the app with increased timeout
  electronApp = await _electron.launch({
    args: ["."],
    env: { NODE_ENV: "development" },
    timeout: 45000,
  });
  // Get the loading window (first window)
  loadingWindow = await electronApp.firstWindow();
  // Wait for loading window to be ready and verify its existence
  await loadingWindow.waitForLoadState("domcontentloaded");
  try {
    // Verify loading window content before it potentially closes
    const loadingContent = await loadingWindow.textContent("body");
    expect(loadingContent).toBeTruthy();
  } catch (error) {
    console.log("Loading window content check failed:", error);
  }
  // Wait for Python server to start and main window to appear
  mainWindow = await waitForMainWindow();
  await mainWindow.waitForLoadState("domcontentloaded");
  await waitForPreloadScript(mainWindow);
});

test.afterEach(async () => {
  if (electronApp) {
    await electronApp.close();
  }
});

test("application startup sequence", async () => {
  // Verify main window appears and is loaded
  await mainWindow.waitForLoadState("domcontentloaded");
  // Verify main window has expected title
  const title = await mainWindow.title();
  expect(title).toBe("Notate");
  // Verify window count
  const windows = await electronApp.windows();
  expect(windows.length).toBeGreaterThanOrEqual(1);
});

test("main window functionality after startup", async () => {
  // Wait for main window to be ready
  await mainWindow.waitForLoadState("domcontentloaded");
  // Get all windows and verify main window state
  const isMinimized = await electronApp.evaluate(({ BrowserWindow }) => {
    const wins = BrowserWindow.getAllWindows();
    // Find the window that's not minimized (should be our main window)
    const mainWin = wins.find((win) => !win.isMinimized());
    return mainWin ? mainWin.isMinimized() : null;
  });
  expect(isMinimized).toBe(false);
});

test("menu structure verification", async () => {
  // Get the application menu
  interface MenuItem {
    label: string;
    submenuLabels: string[];
  }
  const menu = await electronApp.evaluate(({ Menu }) => {
    const appMenu = Menu.getApplicationMenu();
    if (!appMenu) return null;
    return appMenu.items.map((item) => ({
      label: item.label,
      submenuLabels: item.submenu?.items.map((subItem) => subItem.label) || [],
    }));
  });
  // Verify menu exists
  expect(menu).toBeTruthy();
  expect(Array.isArray(menu)).toBe(true);
  // Verify File menu
  const fileMenu = menu?.find((item) => item.label === "File") as MenuItem;
  expect(fileMenu).toBeTruthy();
  expect(fileMenu.label).toBe("File");
  expect(fileMenu.submenuLabels).toContain("Change User");
  expect(fileMenu.submenuLabels).toContain("Quit");
  // Verify Edit menu
  const editMenu = menu?.find((item) => item.label === "Edit") as MenuItem;
  expect(editMenu).toBeTruthy();
  expect(editMenu.label).toBe("Edit");
  expect(editMenu.submenuLabels).toContain("Undo");
  expect(editMenu.submenuLabels).toContain("Redo");
  expect(editMenu.submenuLabels).toContain("Cut");
  expect(editMenu.submenuLabels).toContain("Copy");
  expect(editMenu.submenuLabels).toContain("Paste");
  expect(editMenu.submenuLabels).toContain("Delete");
  expect(editMenu.submenuLabels).toContain("Select All");
  // Verify View menu
  const viewMenu = menu?.find((item) => item.label === "View") as MenuItem;
  expect(viewMenu).toBeTruthy();
  expect(viewMenu.label).toBe("View");
  expect(viewMenu.submenuLabels).toContain("Chat");
  expect(viewMenu.submenuLabels).toContain("History");
  expect(viewMenu.submenuLabels).toContain("Temp DevTools");
});
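
// electronApp.evaluate runs its callback inside Electron's main process; Playwright
// serializes the function, so it can only use the destructured electron modules it
// is handed (Menu, BrowserWindow, app, ...), not variables from test scope. A
// minimal illustration (an editor's sketch, not part of this suite):
//
//   const version = await electronApp.evaluate(({ app }) => app.getVersion());
//   console.log("Running against app version", version);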

test("menu DevTools functionality", async () => {
  // Test menu functionality - Toggle DevTools
  const devToolsVisible = await electronApp.evaluate(({ BrowserWindow }) => {
    const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
    return win?.webContents.isDevToolsOpened() || false;
  });
  expect(devToolsVisible).toBe(false);
  // Toggle DevTools through menu
  await electronApp.evaluate(async ({ Menu, BrowserWindow }) => {
    const appMenu = Menu.getApplicationMenu();
    if (!appMenu) return;
    const viewMenu = appMenu.items.find((item) => item.label === "View");
    if (!viewMenu?.submenu) return;
    const devToolsItem = viewMenu.submenu.items.find(
      (item) => item.label === "Temp DevTools"
    );
    if (devToolsItem) {
      const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
      if (win) {
        win.webContents.toggleDevTools();
        // Add a longer wait time for DevTools to open
        await new Promise((resolve) => setTimeout(resolve, 2000));
      }
    }
  });
  // Verify DevTools is now open
  const devToolsNowVisible = await electronApp.evaluate(({ BrowserWindow }) => {
    const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
    return win?.webContents.isDevToolsOpened() || false;
  });
  expect(devToolsNowVisible).toBe(true);
});

test("menu View functionality", async () => {
  // Wait for initial load
  await mainWindow.waitForLoadState("domcontentloaded");
  // Test View menu functionality - Chat view
  const chatClicked = await electronApp.evaluate(async ({ Menu }) => {
    try {
      const appMenu = Menu.getApplicationMenu();
      if (!appMenu) return false;
      const viewMenu = appMenu.items.find((item) => item.label === "View");
      if (!viewMenu?.submenu) return false;
      const chatItem = viewMenu.submenu.items.find(
        (item) => item.label === "Chat"
      );
      if (!chatItem) return false;
      await chatItem.click();
      return true;
    } catch (error) {
      console.error("Error clicking Chat menu item:", error);
      return false;
    }
  });
  expect(chatClicked).toBe(true);
  // Add a small delay to allow for view change
  await new Promise((resolve) => setTimeout(resolve, 1000));
  // Verify the view changed to Chat
  const isChatView = await mainWindow.evaluate(() => {
    // Try multiple possible selectors
    return Boolean(
      document.querySelector('[data-view="Chat"]') ||
        document.querySelector(".chat-view") ||
        document.querySelector("#chat-view") ||
        // Fall back to a heading containing the app title ("Notate")
        Array.from(document.querySelectorAll("h1,h2,h3,h4,h5,h6")).some((el) =>
          el.textContent?.includes("Notate")
        )
    );
  });
  expect(isChatView).toBe(true);
  // Additional verification - try to find chat-related elements
  const hasChatElements = await mainWindow.evaluate(() => {
    return Boolean(
      document.querySelector('input[type="text"]') || // Chat input
        document.querySelector("textarea") || // Chat input
        document.querySelector(".message") || // Chat messages
        document.querySelector(".chat-container") // Chat container
    );
  });
  expect(hasChatElements).toBe(true);
});

test("menu Change User functionality", async () => {
  // Test menu functionality - Change User
  // Note: This will close the app, so it should be the last test
  const changeUserClicked = await electronApp.evaluate(async ({ Menu }) => {
    try {
      const appMenu = Menu.getApplicationMenu();
      if (!appMenu) return false;
      const fileMenu = appMenu.items.find((item) => item.label === "File");
      if (!fileMenu?.submenu) return false;
      const changeUserItem = fileMenu.submenu.items.find(
        (item) => item.label === "Change User"
      );
      if (!changeUserItem) return false;
      await changeUserItem.click();
      return true;
    } catch (error) {
      console.error("Error clicking Change User menu item:", error);
      return false;
    }
  });
  expect(changeUserClicked).toBe(true);
});
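
// The "click a menu item by label" dance is repeated for the Chat, Change User and
// History navigation in this file; a possible shared helper (an editor's sketch
// only, using Playwright's evaluate(fn, arg) form) would be:
//
//   async function clickMenuItem(menuLabel: string, itemLabel: string): Promise<boolean> {
//     return electronApp.evaluate(
//       async ({ Menu }, labels) => {
//         const topItem = Menu.getApplicationMenu()?.items.find(
//           (item) => item.label === labels.menuLabel
//         );
//         const subItem = topItem?.submenu?.items.find(
//           (item) => item.label === labels.itemLabel
//         );
//         if (!subItem) return false;
//         await subItem.click();
//         return true;
//       },
//       { menuLabel, itemLabel }
//     );
//   }
//
//   expect(await clickMenuItem("View", "Chat")).toBe(true);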

test("keyboard shortcuts and DevTools functionality", async () => {
  // Test common keyboard shortcuts
  await mainWindow.keyboard.press("Control+Z"); // Test Undo
  await mainWindow.keyboard.press("Control+Y"); // Test Redo
  await mainWindow.keyboard.press("Control+A"); // Test Select All
  // Test DevTools using Electron API directly
  await electronApp.evaluate(({ BrowserWindow }) => {
    const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
    if (win && !win.webContents.isDevToolsOpened()) {
      win.webContents.openDevTools();
    }
  });
  // Add a small delay to allow DevTools to open
  await new Promise((resolve) => setTimeout(resolve, 1000));
  const devToolsOpen = await electronApp.evaluate(({ BrowserWindow }) => {
    const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
    return win?.webContents.isDevToolsOpened() || false;
  });
  expect(devToolsOpen).toBe(true);
  // Close DevTools
  await electronApp.evaluate(({ BrowserWindow }) => {
    const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
    if (win && win.webContents.isDevToolsOpened()) {
      win.webContents.closeDevTools();
    }
  });
  // Verify DevTools is closed
  const devToolsClosed = await electronApp.evaluate(({ BrowserWindow }) => {
    const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
    return !win?.webContents.isDevToolsOpened();
  });
  expect(devToolsClosed).toBe(true);
});

test("window state management", async () => {
  // Test minimize with retry logic
  let retries = 3;
  let isMinimized = false;
  while (retries > 0 && !isMinimized) {
    await electronApp.evaluate(({ BrowserWindow }) => {
      const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
      if (win && !win.isMinimized()) {
        win.minimize();
      }
    });
    // Wait longer for the window state to change
    await new Promise((resolve) => setTimeout(resolve, 2000));
    isMinimized = await electronApp.evaluate(({ BrowserWindow }) => {
      const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
      return win?.isMinimized() || false;
    });
    retries--;
  }
  expect(isMinimized).toBe(true);
  // Test restore
  await electronApp.evaluate(({ BrowserWindow }) => {
    const win = BrowserWindow.getAllWindows().find((w) => !w.isDestroyed());
    win?.restore();
  });
});

test("chat interaction flow", async () => {
  // Set up response mocking
  await mainWindow.route("**/chat", async (route) => {
    await route.fulfill({
      status: 200,
      contentType: "application/json",
      body: JSON.stringify({
        id: 1,
        messages: [
          {
            role: "user",
            content: "Test message",
            timestamp: new Date().toISOString(),
          },
          {
            role: "assistant",
            content: "This is a mock AI response",
            timestamp: new Date().toISOString(),
          },
        ],
        title: "Test Conversation",
      }),
    });
  });
  // Navigate to chat view and wait for it to be ready
  const chatClicked = await electronApp.evaluate(async ({ Menu }) => {
    try {
      const appMenu = Menu.getApplicationMenu();
      const viewMenu = appMenu?.items.find((item) => item.label === "View");
      const chatItem = viewMenu?.submenu?.items.find(
        (item) => item.label === "Chat"
      );
      if (!chatItem) return false;
      await chatItem.click();
      return true;
    } catch (error) {
      console.error("Error clicking Chat menu item:", error);
      return false;
    }
  });
  expect(chatClicked).toBe(true);
  // Wait for chat interface to load
  const chatInput = await mainWindow.waitForSelector(
    '[data-testid="chat-input"]',
    {
      timeout: 10000,
      state: "visible",
    }
  );
  expect(chatInput).toBeTruthy();
  // Type the message
  await chatInput.type("Test message");
  // Click the send button instead of pressing Enter
  const sendButton = await mainWindow.waitForSelector(
    '[data-testid="chat-submit"]',
    {
      timeout: 5000,
      state: "visible",
    }
  );
  expect(sendButton).toBeTruthy();
  await sendButton.click();
  // Add debug logging
  console.log("Waiting for user message to appear...");
  // Wait for user message to appear with increased timeout
  const userMessage = await mainWindow.waitForSelector(
    [
      '[data-testid="chat-message-user"]',
      '[data-testid="message-content-user"]',
      ".user-message",
      '.message:has-text("Test message")',
    ].join(","),
    {
      timeout: 20000,
      state: "visible",
    }
  );
  // Add more debug logging
  console.log("User message found, checking content...");
  expect(userMessage).toBeTruthy();
  // Get all text content to debug
  const pageContent = await mainWindow.textContent("body");
  console.log("Page content:", pageContent);
  // Verify the message content
  const messageText = await userMessage.textContent();
  console.log("Message text:", messageText);
  expect(messageText).toContain("Test message");
  // Clean up route
  await mainWindow.unroute("**/chat");
});
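
// mainWindow.route intercepts renderer-originated requests, so the mock above keeps
// this test independent of the Python backend. If more endpoints need stubbing, the
// same pattern generalizes (an editor's sketch only; the payload shape is the one
// used above):
//
//   async function mockJson(page: Page, url: string, payload: unknown) {
//     await page.route(url, (route) =>
//       route.fulfill({
//         status: 200,
//         contentType: "application/json",
//         body: JSON.stringify(payload),
//       })
//     );
//   }
//
//   await mockJson(mainWindow, "**/chat", { id: 1, messages: [], title: "Mocked" });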

test("history view functionality", async () => {
  // Navigate to history view
  const historyClicked = await electronApp.evaluate(async ({ Menu }) => {
    try {
      const appMenu = Menu.getApplicationMenu();
      const viewMenu = appMenu?.items.find((item) => item.label === "View");
      const historyItem = viewMenu?.submenu?.items.find(
        (item) => item.label === "History"
      );
      if (!historyItem) return false;
      await historyItem.click();
      return true;
    } catch (error) {
      console.error("Error clicking History menu item:", error);
      return false;
    }
  });
  expect(historyClicked).toBe(true);
  // Wait for history view to be visible
  const historyView = await mainWindow.waitForSelector(
    '[data-testid="history-view"]',
    {
      timeout: 10000,
      state: "visible",
    }
  );
  expect(historyView).toBeTruthy();
  // Wait for the header to be visible
  const header = await mainWindow.waitForSelector(
    'h1:has-text("Chat History")',
    {
      timeout: 5000,
      state: "visible",
    }
  );
  expect(header).toBeTruthy();
  // Wait for the search input to be visible
  const searchInput = await mainWindow.waitForSelector(
    'input[type="text"][placeholder="Search conversations..."]',
    {
      timeout: 5000,
      state: "visible",
    }
  );
  expect(searchInput).toBeTruthy();
  // Verify the scroll area exists using multiple possible selectors
  const scrollArea = await mainWindow.waitForSelector(
    [
      '[data-testid="history-scroll-area"]',
      ".scroll-area",
      ".scrollarea",
      '[role="scrollarea"]',
      ".overflow-auto",
    ].join(","),
    {
      timeout: 10000,
      state: "visible",
    }
  );
  expect(scrollArea).toBeTruthy();
  // Test search functionality
  await searchInput.type("test");
  await new Promise((resolve) => setTimeout(resolve, 500)); // Wait for search to update
  // Get the entire history view content
  const historyContent = await historyView.textContent();
  expect(historyContent).toContain("Chat History");
});


================================================
FILE: Frontend/electron-builder.json
================================================
{
  "appId": "com.electron.notate",
  "productName": "Notate",
  "extraResources": [
    "dist-electron/preload.cjs",
    {
      "from": "src/assets",
      "to": "assets"
    },
    {
      "from": "../Backend",
      "to": "Backend",
      "filter": ["**/*", "!**/__pycache__", "!**/*.pyc"]
    },
    {
      "from": "node_modules/ffmpeg-static/ffmpeg",
      "to": "ffmpeg"
    },
    {
      "from": "node_modules/ffmpeg-static/ffmpeg.exe",
      "to": "ffmpeg.exe"
    }
  ],
  "asarUnpack": ["Backend", "ffmpeg", "ffmpeg.exe"],
  "files": ["dist-electron", "dist-react", "src/assets/**/*", "build/icons/*"],
  "icon": "./build/icons/icon.icns",
  "mac": {
    "icon": "./build/icons/icon.icns",
    "target": "dmg"
  },
  "win": {
    "icon": "./build/icons/icon.ico",
    "target": [
      "portable",
      {
        "target": "nsis",
        "arch": ["x64"]
      }
    ]
  },
  "linux": {
    "target": [
      "AppImage",
      {
        "target": "deb",
        "arch": ["x64"]
      },
      {
        "target": "rpm",
        "arch": ["x64"]
      }
    ],
    "icon": "build/icons/icon.png",
    "category": "Utility",
    "executableName": "notate",
    "desktop": {
      "Name": "Notate",
      "Comment": "Notate Application",
      "Categories": "Utility;",
      "Type": "Application",
      "StartupWMClass": "Notate",
      "Icon": "notate",
      "Terminal": "false"
    }
  }
}


================================================
FILE: Frontend/eslint.config.js
================================================
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'

export default tseslint.config(
  { ignores: ['dist'] },
  {
    extends: [js.configs.recommended, ...tseslint.configs.recommended],
    files: ['**/*.{ts,tsx}'],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
    },
    plugins: {
      'react-hooks': reactHooks,
      'react-refresh': reactRefresh,
    },
    rules: {
      ...reactHooks.configs.recommended.rules,
      'react-refresh/only-export-components': [
        'warn',
        { allowConstantExport: true },
      ],
    },
  },
)


================================================
FILE: Frontend/index.html
================================================
Notate
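
The electron-builder config above ships the Python backend and the ffmpeg binaries as extraResources and keeps them out of the asar archive via asarUnpack, so the main process can spawn them as real files on disk. A minimal sketch of how a main-process module might resolve those paths at runtime (illustrative only; Notate's actual resolution logic lives in its Electron main code, and the function names and dev-mode layout here are assumptions):

import path from "path";
import { app } from "electron";

// extraResources land under process.resourcesPath in a packaged build; asarUnpack
// keeps Backend and ffmpeg as real files there, so they can be executed directly.
export function backendPath(): string {
  return app.isPackaged
    ? path.join(process.resourcesPath, "Backend")
    : path.join(app.getAppPath(), "..", "Backend"); // dev: repo-level ../Backend (assumed layout)
}

export function ffmpegPath(): string {
  const binary = process.platform === "win32" ? "ffmpeg.exe" : "ffmpeg";
  return app.isPackaged
    ? path.join(process.resourcesPath, binary)
    : path.join(app.getAppPath(), "node_modules", "ffmpeg-static", binary);
}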
================================================ FILE: Frontend/package.json ================================================ { "name": "notate", "description": "Notate is a cross-platform chatbot that can help assist in your research", "author": "Hairetsu ", "license": "MIT", "homepage": "https://github.com/CNTRLAI/notate", "private": true, "version": "1.1.31", "type": "module", "main": "dist-electron/main.js", "scripts": { "test:unit": "vitest src", "test:e2e": "playwright test", "dev:mac": "npm-run-all --parallel dev:react dev:electron-mac", "dev:win": "npm-run-all --parallel dev:react dev:electron-win", "dev:linux": "npm-run-all --parallel dev:react dev:electron-linux", "dev:react": "vite", "dev:electron-mac": "npm run transpile:electron && NODE_ENV=development electron .", "dev:electron-win": "npm run transpile:electron && cross-env NODE_ENV=development electron .", "dev:electron-linux": "npm run transpile:electron && NODE_ENV=development electron .", "build": "tsc -b && vite build", "lint": "eslint .", "preview": "vite preview", "transpile:electron": "tsc --project src/electron/tsconfig.json", "dist:mac": "npm run transpile:electron && npm run build && electron-builder --mac --arm64", "dist:win": "npm run transpile:electron && npm run build && electron-builder --win --x64 --publish never", "dist:linux": "npm run transpile:electron && npm run build && electron-builder --linux --x64" }, "dependencies": { "@anthropic-ai/sdk": "^0.32.1", "@dqbd/tiktoken": "^1.0.18", "@google/generative-ai": "^0.21.0", "@hookform/resolvers": "^3.9.1", "@radix-ui/react-avatar": "^1.1.1", "@radix-ui/react-dialog": "^1.1.5", "@radix-ui/react-label": "^2.1.0", "@radix-ui/react-menubar": "^1.1.4", "@radix-ui/react-popover": "^1.1.2", "@radix-ui/react-progress": "^1.1.1", "@radix-ui/react-radio-group": "^1.2.2", "@radix-ui/react-scroll-area": "^1.2.1", "@radix-ui/react-select": "^2.1.2", "@radix-ui/react-separator": "^1.1.0", "@radix-ui/react-slider": "^1.2.1", "@radix-ui/react-slot": "^1.1.0", "@radix-ui/react-switch": "^1.1.1", "@radix-ui/react-tabs": "^1.1.1", "@radix-ui/react-toast": "^1.2.2", "@radix-ui/react-tooltip": "^1.1.6", "better-sqlite3": "^11.7.0", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "1.0.0", "date-fns": "^4.1.0", "dotenv": "^16.4.7", "electron-log": "^5.2.4", "ffmpeg-static": "^5.2.0", "framer-motion": "^11.15.0", "jsonwebtoken": "^9.0.2", "lucide-react": "^0.462.0", "next-mdx-remote": "^5.0.0", "ollama": "^0.5.12", "openai": "^4.82.0", "os-utils": "^0.0.14", "playwright": "^1.50.1", "react": "^18.3.1", "react-dom": "^18.3.1", "react-dropzone": "^14.3.5", "react-hook-form": "^7.53.2", "react-markdown": "^9.0.3", "rehype-format": "^5.0.1", "rehype-raw": "^7.0.0", "rehype-sanitize": "^6.0.0", "rehype-stringify": "^10.0.1", "remark-frontmatter": "^5.0.0", "remark-gfm": "^4.0.0", "remark-math": "^6.0.0", "remark-parse": "^11.0.0", "remark-rehype": "^11.1.1", "tailwind-merge": "^2.5.5", "tailwindcss-animate": "^1.0.7", "unified": "^11.0.5", "use-clipboard-copy": "^0.2.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" }, "devDependencies": { "@electron/rebuild": "^3.7.1", "@eslint/js": "^9.15.0", "@playwright/test": "^1.49.0", "@tailwindcss/typography": "^0.5.16", "@types/better-sqlite3": "^7.6.12", "@types/hast": "^3.0.4", "@types/jsonwebtoken": "^9.0.7", "@types/node": "^22.10.1", "@types/os-utils": "^0.0.4", "@types/react": "^18.3.12", "@types/react-dom": "^18.3.1", "@types/unist": "^3.0.3", "@vitejs/plugin-react": "^4.3.4", "autoprefixer": "^10.4.20", "cross-env": 
"^7.0.3", "electron": "^33.2.1", "electron-builder": "^25.1.8", "electron-rebuild": "^3.2.9", "eslint": "^9.15.0", "eslint-import-resolver-typescript": "^3.6.3", "eslint-plugin-import": "^2.31.0", "eslint-plugin-react-hooks": "^5.0.0", "eslint-plugin-react-refresh": "^0.4.14", "globals": "^15.12.0", "npm-run-all": "^4.1.5", "postcss": "^8.4.49", "shiki": "^1.24.0", "tailwindcss": "^3.4.15", "typescript": "~5.6.2", "typescript-eslint": "^8.15.0", "unist-util-visit": "^5.0.0", "vite": "^6.0.11", "vitest": "^3.0.5" } } ================================================ FILE: Frontend/playwright.config.ts ================================================ import { defineConfig, devices } from "@playwright/test"; /** * Read environment variables from file. * https://github.com/motdotla/dotenv */ // import dotenv from 'dotenv'; // import path from 'path'; // dotenv.config({ path: path.resolve(__dirname, '.env') }); /** * See https://playwright.dev/docs/test-configuration. */ export default defineConfig({ testDir: "./e2e", /* Run tests in files in parallel */ fullyParallel: true, /* Fail the build on CI if you accidentally left test.only in the source code. */ forbidOnly: !!process.env.CI, /* Retry on CI only */ retries: process.env.CI ? 2 : 0, /* Opt out of parallel tests on CI. */ workers: process.env.CI ? 1 : undefined, /* Reporter to use. See https://playwright.dev/docs/test-reporters */ reporter: "html", /* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */ use: { /* Base URL to use in actions like `await page.goto('/')`. */ // baseURL: 'http://127.0.0.1:3000', /* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */ trace: "on-first-retry", }, /* Configure projects for major browsers */ projects: [ { name: "chromium", use: { ...devices["Desktop Chrome"] }, }, /* { name: 'firefox', use: { ...devices['Desktop Firefox'] }, }, { name: 'webkit', use: { ...devices['Desktop Safari'] }, }, */ /* Test against mobile viewports. */ // { // name: 'Mobile Chrome', // use: { ...devices['Pixel 5'] }, // }, // { // name: 'Mobile Safari', // use: { ...devices['iPhone 12'] }, // }, /* Test against branded browsers. 

================================================
FILE: Frontend/postcss.config.js
================================================
export default {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}


================================================
FILE: Frontend/src/app/App.tsx
================================================
import { useMemo } from "react";
import Chat from "@/components/Chat/Chat";
import { Toaster } from "@/components/ui/toaster";
import { Header } from "@/components/Header/Header";
import { useView } from "@/context/useView";
import CreateAccount from "@/components/Authentication/CreateAccount";
import SelectAccount from "@/components/Authentication/SelectAccount";
import History from "@/components/History/History";
import SettingsAlert from "@/components/AppAlert/SettingsAlert";
import { useSysSettings } from "@/context/useSysSettings";
import { useAppInitialization } from "@/hooks/useAppInitialization";
import FileExplorer from "@/components/FileExplorer/FileExplorer";

function App() {
  const { activeView } = useView();
  const { users } = useSysSettings();
  useAppInitialization();

  // The JSX elements below were stripped by the extraction; they are restored
  // from the imports above and the hook dependencies of this memo.
  const activeUsages = useMemo(() => {
    switch (activeView) {
      case "Chat":
        return <Chat />;
      case "History":
        return <History />;
      case "Signup":
        return <CreateAccount />;
      case "SelectAccount":
        return <SelectAccount users={users} />;
      case "FileExplorer":
        return <FileExplorer />;
      default:
        return null;
    }
  }, [activeView, users]);

  return (
{activeUsages}
); } export default App; ================================================ FILE: Frontend/src/app/index.css ================================================ @tailwind base; @tailwind components; @tailwind utilities; @layer base { :root { --gradient: #4ecdc4; --background: 187 36.400000000000006% 4.48%; --foreground: 187 5.6000000000000005% 97.8%; --muted: 187 28.000000000000004% 16.8%; --muted-foreground: 187 5.6000000000000005% 55.6%; --popover: 187 53.8% 7.280000000000001%; --popover-foreground: 187 5.6000000000000005% 97.8%; --card: 187 53.8% 7.280000000000001%; --card-foreground: 187 5.6000000000000005% 97.8%; --border: 187 28.000000000000004% 16.8%; --input: 187 28.000000000000004% 16.8%; --primary: 187 56% 56%; --primary-foreground: 187 5.6000000000000005% 5.6000000000000005%; --secondary: 187 28.000000000000004% 16.8%; --secondary-foreground: 187 5.6000000000000005% 97.8%; --accent: 187 28.000000000000004% 16.8%; --accent-foreground: 187 5.6000000000000005% 97.8%; --destructive: 0 62.8% 30.6%; --destructive-foreground: 187 5.6000000000000005% 97.8%; --ring: 187 56% 56%; --chart-1: 220 70% 50%; --chart-2: 160 60% 45%; --chart-3: 30 80% 55%; --chart-4: 280 65% 60%; --chart-5: 340 75% 55%; } } @layer base { * { @apply border-border; } body { @apply font-sans antialiased bg-background text-foreground; } } /* Enhanced Button Styles */ .btn-provider { @apply transition-all duration-200 hover:scale-[1.02] active:scale-[0.98]; } .btn-provider.selected { @apply ring-2 ring-primary/80 ring-offset-2 ring-offset-background shadow-lg; } /* Enhanced Input Fields */ .input-field { @apply bg-secondary/5 border border-border/50 rounded-[8px] px-4 py-2.5 focus:ring-2 focus:ring-primary/30 focus:border-primary/50 transition-all duration-200; } /* Card and Container Styles */ .settings-card { @apply bg-secondary/10 backdrop-blur-sm rounded-[8px] p-6 border border-border/50 shadow-lg hover:shadow-primary/5 transition-all duration-300; } .provider-section { @apply space-y-4 bg-background/50 rounded-[8px] p-4 border border-border/50 shadow-inner; } /* Animation Effects */ @keyframes subtle-pulse { 0% { box-shadow: 0 0 0 0 rgba(var(--primary), 0.4); } 70% { box-shadow: 0 0 0 6px rgba(var(--primary), 0); } 100% { box-shadow: 0 0 0 0 rgba(var(--primary), 0); } } @keyframes dot { 0%, 20% { opacity: 0; } 50% { opacity: 1; } 100% { opacity: 0; } } .animate-pulse-subtle { animation: subtle-pulse 2s infinite; } /* Glassmorphism Effects */ .glass-panel { @apply bg-background/80 backdrop-blur-md border border-white/10 shadow-xl rounded-[8px]; } .glass-input { @apply bg-white/5 backdrop-blur-sm border border-white/10 focus:border-primary/50 focus:bg-white/10 transition-all duration-200; } /* Existing Window Control Styles */ #close { background-color: #4a4a4a; } .focus-within #close { background-color: #ff6057; } #minimize { background-color: #4a4a4a; } .focus-within #minimize { background-color: #ffbd2e; } #maximize { background-color: #4a4a4a; } .focus-within #maximize { background-color: #27c93f; } #unmaximize { background-color: #4a4a4a; } .focus-within #unmaximize { background-color: #27c93f; } /* Existing Header Styles */ .win-header-button { -webkit-app-region: no-drag; display: inline-flex; justify-content: center; align-items: center; width: 46px; height: 32px; background: transparent; border: none; outline: none; color: #fff; font-family: "Segoe MDL2 Assets", "Segoe UI", sans-serif; font-size: 10px; } .win-header-button:hover { background: rgba(255, 255, 255, 0.1); } .win-header-button.win-close:hover { 
background: #e81123; } .win-header-button span { font-size: 16px; line-height: 1; } .win-header-button.win-maximize { background: #0078d4; } .win-header-button.win-restore { background: #0078d4; } header { position: fixed; top: 0; left: 0; width: 100%; text-align: left; padding-inline: 2px; box-sizing: border-box; background-color: #181818; -webkit-app-region: drag; z-index: 100; } .header-button { all: unset; border-radius: 50%; width: 0.75rem; height: 0.75rem; margin: 0 0.25rem; -webkit-app-region: no-drag; } .clickable-header-section { cursor: pointer; -webkit-app-region: no-drag; } .header-button:hover { opacity: 0.8; } .window-controls:hover button span { display: block; } /* Code Styles */ code { white-space: pre-wrap !important; } /* Scrollbar Styles */ ::-webkit-scrollbar { width: 8px; height: 8px; } ::-webkit-scrollbar-track { @apply bg-secondary/20 rounded-full; } ::-webkit-scrollbar-thumb { @apply bg-secondary/60 rounded-full hover:bg-secondary/80 transition-colors; } ::-webkit-scrollbar-corner { @apply bg-transparent; } /* Markdown Styles */ .contentMarkdown { @apply text-foreground space-y-3; } .contentMarkdown h1 { @apply text-3xl font-bold mb-4 text-foreground/90; } .contentMarkdown h2 { @apply text-2xl font-bold mb-3 text-foreground/90; } .contentMarkdown h3 { @apply text-xl font-semibold mb-2.5 mt-4 text-foreground/90; } .contentMarkdown p { @apply mb-3 leading-relaxed text-foreground/80; } .contentMarkdown ul { @apply mb-3 space-y-1.5 list-none ml-4; } .contentMarkdown ol { @apply mb-3 space-y-4 list-none; } .contentMarkdown li { @apply relative pl-8 leading-relaxed text-foreground/80; } .contentMarkdown ul > li::before { @apply absolute left-0 top-[0.7em] w-2 h-2 rounded-full bg-primary/70 -translate-y-1/2; content: ""; } .contentMarkdown ol { counter-reset: item; } .contentMarkdown ol > li { counter-increment: item; @apply pl-0; } .contentMarkdown ol > li::before { @apply hidden; } /* Nested lists */ .contentMarkdown li > ul { @apply mt-2 mb-0 ml-4; } .contentMarkdown li > ol { @apply mt-2 mb-0; } /* Nested list items should have smaller bullets */ .contentMarkdown li > ul > li::before { @apply w-1.5 h-1.5 bg-primary/50; } .contentMarkdown li > ul > li { @apply pl-6; } .contentMarkdown code { @apply px-2 py-0.5 rounded-md bg-secondary/40 text-primary/90 font-mono text-[13px] border border-secondary/50; } .contentMarkdown strong { @apply font-semibold text-primary/90; } .contentMarkdown a { @apply text-primary/90 hover:text-primary hover:underline decoration-primary/30 transition-colors duration-200; } .contentMarkdown blockquote { @apply border-l-4 border-primary/40 pl-4 italic my-4 text-foreground/70 bg-secondary/20 py-2 pr-3 rounded-r-lg; } .contentMarkdown pre { @apply p-4 rounded-lg bg-secondary/30 overflow-x-auto my-4 border border-secondary/50; } .contentMarkdown pre code { @apply bg-transparent p-0 text-foreground/90 border-0; } .contentMarkdown ul ul, .contentMarkdown ol ol, .contentMarkdown ul ol, .contentMarkdown ol ul { @apply mt-2 mb-0; } .contentMarkdown li > p { @apply inline; } ================================================ FILE: Frontend/src/app/main.tsx ================================================ import { StrictMode } from "react"; import { createRoot } from "react-dom/client"; import "./index.css"; import App from "./App"; import UserClientProviders from "@/context/UserClientProviders"; createRoot(document.getElementById("root")!).render( ); ================================================ FILE: Frontend/src/app/vite-env.d.ts 
================================================
/// <reference types="vite/client" />

================================================
FILE: Frontend/vite.config.ts
================================================
import { defineConfig } from "vite";
import react from "@vitejs/plugin-react";
import path from "path";

// https://vitejs.dev/config/
export default defineConfig({
  plugins: [react()],
  base: "./",
  resolve: {
    alias: {
      "@": path.resolve(__dirname, "./src"),
      "@/ui": path.resolve(__dirname, "./src/app"),
      "@/components": path.resolve(__dirname, "./src/components"),
    },
  },
});


================================================
FILE: Frontend/src/components/AppAlert/SettingsAlert.tsx
================================================
import {
  Dialog,
  DialogContent,
  DialogTitle,
  DialogDescription,
} from "@/components/ui/dialog";
import { useUser } from "@/context/useUser";
import LLMPanel from "@/components/SettingsModal/SettingsComponents/LLMPanel";

export default function SettingsAlert() {
  const { alertForUser, setAlertForUser } = useUser();
  return (
    LLM Settings Please add an API key or Select Local Model Deployment
*Local Model Deployment requires Ollama to be installed and running
); } ================================================ FILE: Frontend/src/components/Authentication/CreateAccount.tsx ================================================ import { Label } from "@/components/ui/label"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardHeader, CardTitle, CardDescription, CardFooter, } from "@/components/ui/card"; import { Input } from "@/components/ui/input"; import { useView } from "@/context/useView"; import { useState, useEffect } from "react"; import { motion, AnimatePresence } from "framer-motion"; import { useSysSettings } from "@/context/useSysSettings"; import { useUser } from "@/context/useUser"; export default function CreateAccount() { const { setActiveView } = useView(); const { users, setUsers } = useSysSettings(); const { setActiveUser } = useUser(); const [accountName, setAccountName] = useState(""); const [error, setError] = useState(""); const [currentStep, setCurrentStep] = useState(0); const steps = [ { title: "Welcome to Notate", subtitle: null }, { title: "Your Research Hack Tool", subtitle: null }, { title: "Create Account", subtitle: "Enter your name to create an account", }, ]; useEffect(() => { const timer = setTimeout(() => { if (currentStep < steps.length - 1) { setCurrentStep((prevStep) => prevStep + 1); } }, 1500); return () => clearTimeout(timer); }, [currentStep, steps.length]); const handleCreateAccount = async () => { if (!accountName.trim()) { setError("Please enter a name"); return; } try { const user = await window.electron.addUser(accountName); if (user.error) { setError(user.error); return; } const allUsers = (await window.electron.getUsers()).users; const activeUser = allUsers.find((u) => u.name === user.name); if (activeUser) { setActiveUser(activeUser); setUsers(allUsers); setActiveView("Chat"); } else { setError("Failed to create account. Please try again."); } } catch (err) { setError("Failed to create account. Please try again."); console.error(err); } }; const handleBack = () => setActiveView("SelectAccount"); const fadeInUp = { initial: { opacity: 0, y: 20 }, animate: { opacity: 1, y: 0 }, exit: { opacity: 0, y: -20 }, transition: { duration: 0.5 }, }; return (
{currentStep < 2 ? ( <>

{steps[currentStep].title}

{steps[currentStep].subtitle && (

{steps[currentStep].subtitle}

)} ) : ( {steps[currentStep].title} {steps[currentStep].subtitle}
setAccountName(e.target.value)} />
{error &&

{error}

}
{users.length > 0 && ( )}
)}
); } ================================================ FILE: Frontend/src/components/Authentication/SelectAccount.tsx ================================================ import { Avatar, AvatarFallback } from "@/components/ui/avatar"; import { Card, CardContent } from "@/components/ui/card"; import { useView } from "@/context/useView"; import { Button } from "@/components/ui/button"; import { useUser } from "@/context/useUser"; import { useSysSettings } from "@/context/useSysSettings"; import { motion } from "framer-motion"; import { Plus } from "lucide-react"; import { ScrollArea } from "@/components/ui/scroll-area"; const containerVariants = { hidden: { opacity: 0 }, visible: { opacity: 1, transition: { duration: 0.5, staggerChildren: 0.1 } }, }; const itemVariants = { hidden: { opacity: 0, y: 20 }, visible: { opacity: 1, y: 0 }, }; const MotionAvatar = motion.create(Avatar); export default function SelectAccount({ users }: { users: User[] }) { const { setActiveView } = useView(); const { activeUser, setActiveUser } = useUser(); const { setSettings } = useSysSettings(); const fetchSettings = async () => { if (activeUser) { const userSettings = await window.electron.getUserSettings(activeUser.id); setSettings(userSettings); } }; const handleSelectAccount = (user: User) => { setActiveUser(user); fetchSettings(); setActiveView("Chat"); }; return (
Select Account
Choose your account to access your workspace
{users.map((user) => ( handleSelectAccount(user)} > {user.name.charAt(0).toUpperCase()}

{user.name}

Click to access workspace

))}
  );
}


================================================
FILE: Frontend/src/components/Chat/Chat.tsx
================================================
import { ArrowDown, Loader2 } from "lucide-react";
import { Button } from "@/components/ui/button";
import { useUser } from "@/context/useUser";
import { useChatInput } from "@/context/useChatInput";
import { useChatLogic } from "@/hooks/useChatLogic";
import { ChatMessagesArea } from "./ChatComponents/ChatMessagesArea";
import { ChatInput } from "./ChatComponents/ChatInput";
import { LoadingIndicator } from "./ChatComponents/LoadingIndicator";
import { useSysSettings } from "@/context/useSysSettings";
import { IngestProgress } from "../CollectionModals/CollectionComponents/IngestProgress";

export default function Chat() {
  const {
    scrollAreaRef,
    resetCounter,
    bottomRef,
    showScrollButton,
    scrollToBottom,
  } = useChatLogic();
  const { localModalLoading } = useSysSettings();
  const { streamingMessage, streamingMessageReasoning, messages, error } =
    useUser();
  const { isLoading } = useChatInput();

  return (
{showScrollButton && ( )}
{localModalLoading && (
Loading local model...
)}
{isLoading && (
)}
  );
}


================================================
FILE: Frontend/src/components/Chat/ChatComponents/ChatHeader.tsx
================================================
import { Button } from "@/components/ui/button";
import { PlusCircle } from "lucide-react";
import { Loader2 } from "lucide-react";
import { IngestProgress } from "@/components/CollectionModals/CollectionComponents/IngestProgress";
import logo from "@/assets/icon.png";
import { useSysSettings } from "@/context/useSysSettings";
import { useChatLogic } from "@/hooks/useChatLogic";

export function ChatHeader() {
  const { localModalLoading } = useSysSettings();
  const { handleResetChat } = useChatLogic();
  return (
logo

Notate

{localModalLoading && (
Loading local model...
)}
  );
}


================================================
FILE: Frontend/src/components/Chat/ChatComponents/ChatInput.tsx
================================================
import { LibraryModal } from "@/components/CollectionModals/LibraryModal";
import { Button } from "@/components/ui/button";
import {
  Sheet,
  SheetTitle,
  SheetHeader,
  SheetContent,
  SheetDescription,
  SheetTrigger,
} from "@/components/ui/sheet";
import { Textarea } from "@/components/ui/textarea";
import { Library, Send, X, Mic, Loader2, Globe } from "lucide-react";
import { useState, useEffect, useMemo, useCallback, memo } from "react";
import { useUser } from "@/context/useUser";
import { useChatInput } from "@/context/useChatInput";
import { useSysSettings } from "@/context/useSysSettings";
import {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
} from "@/components/ui/tooltip";
import { WebAudioRecorder } from "@/utils/webAudioRecorder";
import { useLibrary } from "@/context/useLibrary";

export const ChatInput = memo(function ChatInput() {
  const { activeUser, toggleTool, userTools, activeConversation } = useUser();
  const {
    handleChatRequest,
    cancelRequest,
    input,
    setInput,
    isLoading,
    setIsLoading,
  } = useChatInput();
  const { openLibrary, setOpenLibrary } = useLibrary();
  const [isRecording, setIsRecording] = useState(false);
  const [transcriptionLoading, setTranscriptionLoading] = useState(false);
  const [loadingDots, setLoadingDots] = useState("");
  const { isFFMPEGInstalled } = useSysSettings();
  const audioRecorder = useMemo(() => new WebAudioRecorder(), []);
  const { selectedCollection } = useLibrary();

  // Memoize the loading dots animation interval
  useEffect(() => {
    if (!transcriptionLoading) {
      setLoadingDots("");
      return;
    }
    const interval = setInterval(() => {
      setLoadingDots((prev: string) => (prev === "..." ? "" : prev + "."));
    }, 500);
    return () => clearInterval(interval);
  }, [transcriptionLoading]);
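
  // The recording flow below leans on src/utils/webAudioRecorder. Inferred from
  // usage in this component (an editor's illustrative sketch; anything beyond
  // startRecording / stopRecording is not confirmed by this extract):
  //
  //   interface WebAudioRecorder {
  //     startRecording(): Promise<void>;   // request the mic and begin capture
  //     stopRecording(): Promise<unknown>; // resolve with the audio payload that is
  //   }                                    // passed to window.electron.transcribeAudio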
" " : "") + result.transcription; return newInput; }); } else { console.warn("No transcription in result:", result); } setTranscriptionLoading(false); } } catch (error) { console.error("Error handling recording:", error); setIsRecording(false); setTranscriptionLoading(false); } }, [isRecording, audioRecorder, activeUser?.id, setInput]); // Memoize the tooltip content const tooltipContent = useMemo(() => { if (!isFFMPEGInstalled) return "Please install FFMPEG to use voice-to-text"; if (transcriptionLoading) return "Transcribing your audio..."; if (isRecording) return "Click to stop recording"; return "Click to start voice recording"; }, [isFFMPEGInstalled, transcriptionLoading, isRecording]); // Memoize the form submit handler const handleSubmit = useCallback( (e: React.FormEvent) => { e.preventDefault(); if (input.trim()) { handleChatRequest(selectedCollection?.id || undefined); } }, [input, handleChatRequest, selectedCollection?.id] ); // Memoize the send button handler const handleSendClick = useCallback(async () => { if (isLoading) { cancelRequest(); setIsLoading(false); } else if (input.trim()) { await handleChatRequest(selectedCollection?.id || undefined); } }, [ isLoading, input, cancelRequest, handleChatRequest, selectedCollection?.id, setIsLoading, ]); return (
Data Store Library

{tooltipContent}

{userTools.map((tool) => ( ))}