import asyncio
import json
import logging
import os
import uuid
from typing import Any, AsyncIterator, Dict, Optional, Tuple

import aiohttp
|
|
# Module-scoped logger; its name is this module's dotted import path, so
# levels and handlers can be configured per module via the logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
class SSEClient:
    """Async SSE client for streaming chat API requests.

    The endpoint URL is read from the ``API_ENDPOINT`` environment variable
    when the client is constructed.
    """

    def __init__(self):
        self.url = os.getenv("API_ENDPOINT")
        self.headers = {
            'Content-Type': 'application/json',
            'User-Agent': 'HuggingFace-Gradio-Demo'
        }

    async def stream_chat(self, query: str,
                          deep_thinking_mode: bool = False,
                          search_before_planning: bool = False,
                          debug: bool = False,
                          chat_id: Optional[str] = None) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return streaming data with event parsing

        Each non-empty ``data:`` line is yielded as soon as it arrives
        (multi-line ``data`` fields are not concatenated), and ``data: [DONE]``
        is dropped. SSE comments (``:``-prefixed lines) and ``id``/``retry``
        fields are skipped. Lines that are not SSE fields are passed through
        verbatim under event type ``'data'``.

        Args:
            query: User query content
            deep_thinking_mode: Whether to enable deep thinking mode, default False
            search_before_planning: Whether to search before planning, default False
            debug: Whether to enable debug mode, default False
            chat_id: Chat ID, will be auto-generated if not provided

        Yields:
            Dict[str, Any]: SSE event data with 'event' and 'data' fields
                ('data' is the raw string payload)

        Raises:
            ValueError: If ``API_ENDPOINT`` is not configured.
            Exception: ``"SSE request error: ..."`` for a non-200 status or any
                transport/decoding failure; the original error is attached as
                ``__cause__``.
        """
        if chat_id is None:
            chat_id = self._generate_chat_id()
        if not self.url:
            # Fail fast with a clear message instead of an opaque aiohttp URL error.
            raise ValueError("API_ENDPOINT environment variable is not set")

        payload = {
            "messages": [{
                "id": chat_id,
                "role": "user",
                "type": "text",
                "content": query
            }],
            "deep_thinking_mode": deep_thinking_mode,
            "search_before_planning": search_before_planning,
            "debug": debug,
            "chatId": chat_id
        }

        # total=None: no overall timeout, so long-running streams are not cut off.
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=None)
        ) as session:
            try:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        raise Exception(f"Request failed with status code: {response.status}")

                    current_event = None
                    # Iterating the response body yields it one line at a time.
                    async for raw_line in response.content:
                        line = raw_line.decode('utf-8').strip()
                        current_event, event = self._parse_sse_line(line, current_event)
                        if event is not None:
                            yield event

            except asyncio.CancelledError:
                # Let cancellation propagate untouched (it subclassed Exception before 3.8).
                raise
            except Exception as e:
                raise Exception(f"SSE request error: {str(e)}") from e

    @staticmethod
    def _parse_sse_line(line: str, current_event: Optional[str]
                        ) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """Apply one stripped line of an SSE stream to the parser state.

        Args:
            line: A single decoded, whitespace-stripped line.
            current_event: Event name set by a preceding ``event:`` line, if any.

        Returns:
            ``(next_current_event, event_to_yield_or_None)``.
        """
        if not line:
            # A blank line terminates an SSE event, so a pending event name is dropped.
            return None, None
        if line.startswith(':'):
            # SSE comment (commonly a keep-alive ping): not delivered to callers.
            return current_event, None

        field, sep, value = line.partition(':')
        if sep:
            # Per the SSE format, a single space after the colon is not part of the value.
            if value.startswith(' '):
                value = value[1:]
            if field == 'event':
                return value, None
            if field == 'data':
                if value and value != '[DONE]':
                    return None, {'event': current_event or 'message', 'data': value}
                return None, None
            if field in ('id', 'retry'):
                # Reconnection metadata; this client does not use it.
                return current_event, None

        # Not an SSE field (e.g. a bare JSON line): pass the raw line through.
        return None, {'event': current_event or 'data', 'data': line}

    def _generate_chat_id(self) -> str:
        """Generate a chat ID: the first 21 lowercase hex digits of a random UUID4."""
        return uuid.uuid4().hex[:21]

    async def stream_chat_parsed(self, query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return parsed JSON data with event structure

        Args:
            query: User query content
            **kwargs: Other parameters passed to stream_chat

        Yields:
            Dict[str, Any]: Event data with 'event' and 'data' fields, where 'data'
                contains parsed JSON when the payload is valid JSON and the raw
                string otherwise
        """
        async for event_data in self.stream_chat(query, **kwargs):
            try:
                parsed_data = json.loads(event_data['data'])
                yield {
                    'event': event_data['event'],
                    'data': parsed_data
                }
            except json.JSONDecodeError:
                # Non-JSON payloads are forwarded unchanged.
                yield event_data
            except (KeyError, TypeError):
                # Malformed event dicts are skipped.
                continue
|
|
|
|
| |
async def request_sse_stream(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Convenience function: stream raw SSE events for *query*.

    Builds a fresh SSEClient and relays every item its ``stream_chat``
    method produces, without modification.

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to SSEClient.stream_chat

    Yields:
        Dict[str, Any]: Raw event data with 'event' and 'data' fields (data as string)
    """
    events = SSEClient().stream_chat(query, **kwargs)
    async for event in events:
        yield event
|
|
|
|
async def request_sse_stream_parsed(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Convenience function: stream structured SSE events for *query*.

    Builds a fresh SSEClient and relays every item its ``stream_chat_parsed``
    method produces, without modification.

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to SSEClient.stream_chat_parsed

    Yields:
        Dict[str, Any]: Event data with 'event' and 'data' fields
    """
    events = SSEClient().stream_chat_parsed(query, **kwargs)
    async for event in events:
        yield event
|
|
|
|
async def stop_chat(chat_id: str) -> Any:
    """Ask the backend to stop a chat.

    The endpoint is read from the ``STOP_CHAT_API_ENDPOINT`` environment
    variable; the request body is ``{"chatId": chat_id}``.

    Args:
        chat_id: ID of the chat to stop.

    Returns:
        The JSON-decoded response body.

    Raises:
        ValueError: If ``STOP_CHAT_API_ENDPOINT`` is not set.
        Exception: If the endpoint responds with a non-200 status.
    """
    url = os.getenv('STOP_CHAT_API_ENDPOINT')
    if not url:
        # Fail fast: an unset variable would otherwise become the bogus URL "None".
        raise ValueError("STOP_CHAT_API_ENDPOINT environment variable is not set")
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json={"chatId": chat_id}) as response:
            if response.status != 200:
                logger.error("Request failed with status code: %s", response.status)
                raise Exception(f"Request failed with status code: {response.status}")
            return await response.json()
|
|
| |
async def main():
    """Example usage: print every parsed SSE event returned for a sample query."""
    sample_query = "Hello"
    separator = "-" * 40

    print("=== SSE Event Stream ===")
    async for item in request_sse_stream_parsed(sample_query):
        kind = item.get('event', 'unknown')
        payload = item.get('data', {})
        print(f"Event: {kind}")
        print(f"Data: {payload}")
        print(separator)


if __name__ == "__main__":
    # Run the example on a fresh event loop only when executed as a script.
    asyncio.run(main())
|
|