diff --git a/src/main/java/cn/com/tenlion/operator/util/XinghuoAssistantClient.java b/src/main/java/cn/com/tenlion/operator/util/XinghuoAssistantClient.java new file mode 100644 index 0000000..6e634c5 --- /dev/null +++ b/src/main/java/cn/com/tenlion/operator/util/XinghuoAssistantClient.java @@ -0,0 +1,241 @@ +package cn.com.tenlion.operator.util; + +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSON; +import com.alibaba.fastjson2.JSONObject; +import okhttp3.*; +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.concurrent.TimeUnit; + +public class XinghuoAssistantClient { + private final OkHttpClient client; + private final String appId; + private final String apiKey; + private final String apiSecret; // 新增:星火鉴权必须的API Secret + private final String assistantId; + private final String wsUrl; // WebSocket接口地址 + + // 构造方法:新增apiSecret参数(你的Secret是ZGI1M2M4MDBlNjRhNDhhMTg1YTM4OWE1) + public XinghuoAssistantClient(String appId, String apiKey, String apiSecret, String assistantId) { + this.appId = appId; + this.apiKey = apiKey; + this.apiSecret = apiSecret; + this.assistantId = assistantId; + this.wsUrl = "wss://spark-openapi.cn-huabei-1.xf-yun.com/v1/assistants/" + assistantId; + + // 初始化OkHttp客户端(保持超时配置) + this.client = new OkHttpClient.Builder() + .connectTimeout(30, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .build(); + } + + /** + * 向星火助手发送消息并获取AI回复(支持流式分段响应) + * @param userQuery 用户输入的问题文本 + * @return 完整的AI回复 + * @throws Exception 网络异常、鉴权失败等 + */ + public String askQuestion(String userQuery) throws Exception { + final StringBuilder aiResponseBuilder = new StringBuilder(); // 拼接流式响应 + final boolean[] isResponseComplete = {false}; // 标记响应是否完成 + final boolean[] isWebSocketClosed = {false}; // 标记连接是否关闭 + final Exception[] requestException = {null}; // 捕获回调中的异常 + + // 1. 生成鉴权头(核心:替换Bearer Token) + String authHeader = generateAuthHeader(); + + // 2. 构建请求体JSON(已修正domain为generalv3) + String requestBodyJson = buildRequestBodyJson(userQuery); + + // 3. 创建WebSocket请求(无RequestBody,仅携带鉴权头) + Request request = new Request.Builder() + .url(wsUrl) + .addHeader("Authorization", authHeader) // 使用签名鉴权 + .addHeader("Content-Type", "application/json") + .build(); + + WebSocket webSocket = null; + try { + // 4. 建立WebSocket连接,在onOpen中发送请求体 + webSocket = client.newWebSocket(request, new WebSocketListener() { + @Override + public void onOpen(WebSocket webSocket, Response response) { + System.out.println("✅ WebSocket connected successfully!"); + // 连接建立后,发送请求体(关键修正:这里才是正确的发送时机) + webSocket.send(requestBodyJson); + } + + @Override + public void onMessage(WebSocket webSocket, String text) { + try { + // 解析星火的响应(支持流式分段返回,status=2表示响应完成) + JSONObject jsonObj = JSON.parseObject(text); + int headerCode = jsonObj.getJSONObject("header").getIntValue("code"); + + // 先判断服务端是否返回错误(如鉴权失败、参数错误) + if (headerCode != 0) { + String errorMsg = jsonObj.getJSONObject("header").getString("message"); + requestException[0] = new RuntimeException("Server error: " + headerCode + " - " + errorMsg); + webSocket.close(1001, "Server error"); + return; + } + + // 解析AI回复(流式返回可能分多段,需拼接) + JSONObject payload = jsonObj.getJSONObject("payload"); + if (payload != null && payload.containsKey("choices")) { + JSONObject choices = payload.getJSONObject("choices"); + int choicesStatus = choices.getIntValue("status"); // status=1:分段,status=2:完成 + JSONArray textArray = choices.getJSONArray("text"); + + for (int i = 0; i < textArray.size(); i++) { + JSONObject item = textArray.getJSONObject(i); + if ("assistant".equals(item.getString("role"))) { + String segment = item.getString("content"); + aiResponseBuilder.append(segment); // 拼接分段内容 + System.out.println("💬 AI Segment: " + segment); + } + } + + // 响应完成(status=2),标记结束 + if (choicesStatus == 2) { + isResponseComplete[0] = true; + System.out.println("✅ AI response completed!"); + } + } + } catch (Exception e) { + requestException[0] = e; // 捕获解析异常 + webSocket.close(1002, "Parse error"); + } + } + + @Override + public void onClosing(WebSocket webSocket, int code, String reason) { + System.out.println("⚠️ Closing connection: code=" + code + ", reason=" + reason); + } + + @Override + public void onClosed(WebSocket webSocket, int code, String reason) { + System.out.println("🔒 Connection closed: code=" + code + ", reason=" + reason); + isWebSocketClosed[0] = true; + } + + @Override + public void onFailure(WebSocket webSocket, Throwable t, Response response) { + requestException[0] = new RuntimeException("WebSocket failure", t); // 捕获连接失败异常 + isWebSocketClosed[0] = true; + } + }); + + // 5. 等待响应完成或超时(延长超时时间到30秒,适配流式返回) + long startTime = System.currentTimeMillis(); + while (!isResponseComplete[0] && !isWebSocketClosed[0] && requestException[0] == null) { + Thread.sleep(100); // 每100ms检查一次状态 + long elapsed = System.currentTimeMillis() - startTime; + // 超时时间设为30秒(星火流式返回可能需要较长时间) + if (elapsed > 30 * 1000) { + throw new RuntimeException("Timeout waiting for AI response (elapsed: " + elapsed + "ms)"); + } + } + + // 6. 检查是否有异常(如鉴权失败、解析错误) + if (requestException[0] != null) { + throw requestException[0]; + } + + // 返回完整的AI回复 + return aiResponseBuilder.toString(); + } finally { + // 关闭WebSocket连接(避免资源泄漏) + if (webSocket != null && !isWebSocketClosed[0] ) { + webSocket.close(1000, "Normal closure"); + } + } + } + + /** + * 生成星火API要求的鉴权头(核心:HMAC-SHA256签名) + * 鉴权逻辑参考:https://www.xfyun.cn/doc/spark/Web.html#_2-1-%E9%89%B4%E6%9D%83%E8%AF%B4%E6%98%8E + */ + private String generateAuthHeader() throws Exception { + // 1. 获取当前时间戳(秒级) + String timestamp = String.valueOf(System.currentTimeMillis() / 1000); + // 2. 拼接签名原文(格式:apiKey + timestamp) + String signatureOrigin = apiKey + timestamp; + // 3. HMAC-SHA256加密(密钥是apiSecret,编码UTF-8) + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "HmacSHA256")); + byte[] signatureBytes = mac.doFinal(signatureOrigin.getBytes(StandardCharsets.UTF_8)); + // 4. Base64编码签名结果 + String signature = Base64.getEncoder().encodeToString(signatureBytes); + // 5. 拼接鉴权头(格式:api_key="xxx",timestamp="xxx",signature="xxx") + return String.format("api_key=\"%s\",timestamp=\"%s\",signature=\"%s\"", + apiKey, timestamp, signature); + } + + /** + * 构建符合星火API规范的请求体JSON(已修正domain为generalv3) + */ + private String buildRequestBodyJson(String userQuery) { + // 1. header部分(必传appId,uid可选) + JSONObject header = new JSONObject(); + header.put("app_id", appId); + header.put("uid", "user_123456"); // 建议传唯一用户ID(便于问题排查) + + // 2. parameter部分(domain修正为generalv3,其他参数保持默认) + JSONObject chatParams = new JSONObject(); + chatParams.put("domain", "generalv3"); // 关键修正:通用模型必须传generalv3 + chatParams.put("temperature", 0.5); // 采样阈值(0-1,值越大越随机) + chatParams.put("top_k", 4); // TopK选择(1-6,默认4) + chatParams.put("max_tokens", 2048); // 最大回复长度(1-4096,默认2048) + + JSONObject parameter = new JSONObject(); + parameter.put("chat", chatParams); + + // 3. payload部分(用户问题,role固定为user) + JSONArray textArray = new JSONArray(); + JSONObject userMsg = new JSONObject(); + userMsg.put("role", "user"); + userMsg.put("content", userQuery); + textArray.add(userMsg); + + JSONObject message = new JSONObject(); + message.put("text", textArray); + + JSONObject payload = new JSONObject(); + payload.put("message", message); + + // 4. 组装根JSON + JSONObject rootJson = new JSONObject(); + rootJson.put("header", header); + rootJson.put("parameter", parameter); + rootJson.put("payload", payload); + + return rootJson.toJSONString(); + } + + // 主方法:测试调用(替换为你的实际参数) + public static void main(String[] args) { + try { + // 你的实际参数(从星火平台获取) + String appId = "a9514070"; + String apiKey = "20a8e5c33dc8694925d4a5777cb19d53"; + String apiSecret = "ZGI1M2M4MDBlNjRhNDhhMTg1YTM4OWE1"; // 你的API Secret + String assistantId = "hz96ujdxn91w_v1"; + + // 创建客户端并调用 + XinghuoAssistantClient client = new XinghuoAssistantClient(appId, apiKey, apiSecret, assistantId); + String answer = client.askQuestion("今天的天气怎么样?"); + + // 打印最终结果 + System.out.println("\n📌 Final AI Answer:\n" + answer); + } catch (Exception e) { + System.err.println("❌ Call failed: "); + e.printStackTrace(); + } + } +}