parent
ca470fdb17
commit
0c8326d222
@ -0,0 +1,162 @@ |
||||
package com.cjy.traceability.module.traceability.controller.admin.ai; |
||||
|
||||
import cn.hutool.core.convert.Convert; |
||||
import cn.hutool.core.util.IdUtil; |
||||
import com.alibaba.druid.util.StringUtils; |
||||
import com.alibaba.fastjson.JSON; |
||||
import com.alibaba.fastjson.JSONArray; |
||||
import com.cjy.traceability.framework.common.enums.UserTypeEnum; |
||||
import com.cjy.traceability.framework.common.pojo.CommonResult; |
||||
import com.cjy.traceability.framework.security.core.LoginUser; |
||||
import com.cjy.traceability.framework.security.core.util.SecurityFrameworkUtils; |
||||
import com.cjy.traceability.module.infra.api.websocket.WebSocketSenderApi; |
||||
import com.cjy.traceability.module.traceability.controller.admin.ai.util.RedisCacheUtil; |
||||
import com.cjy.traceability.module.traceability.controller.admin.ai.vo.LargeModeVO; |
||||
import jodd.util.StringUtil; |
||||
import okhttp3.*; |
||||
import okio.BufferedSource; |
||||
import okio.Okio; |
||||
import org.json.JSONObject; |
||||
import org.springframework.beans.factory.annotation.Autowired; |
||||
import org.springframework.web.bind.annotation.GetMapping; |
||||
import org.springframework.web.bind.annotation.PostMapping; |
||||
import org.springframework.web.bind.annotation.RequestMapping; |
||||
import org.springframework.web.bind.annotation.RestController; |
||||
|
||||
import javax.annotation.Resource; |
||||
import javax.servlet.http.HttpServletRequest; |
||||
import java.io.IOException; |
||||
import java.util.ArrayList; |
||||
import java.util.HashMap; |
||||
import java.util.List; |
||||
import java.util.Map; |
||||
import java.util.concurrent.TimeUnit; |
||||
|
||||
import static com.cjy.traceability.framework.common.pojo.CommonResult.success; |
||||
|
||||
/** |
||||
* @ClassName: LargeModelAIController |
||||
* @Author: zc |
||||
* @Date: 2024/4/18 14:09 |
||||
* @Description: TODO |
||||
*/ |
||||
@RestController |
||||
@RequestMapping("/largeMode/ai") |
||||
public class LargeModelAIController { |
||||
|
||||
@Resource |
||||
private WebSocketSenderApi webSocketSenderApi; |
||||
@Autowired |
||||
private RedisCacheUtil redisCache; |
||||
|
||||
public static final String API_KEY = "psK0ClhB9qtPl4KvqU6eCmQZ"; |
||||
public static final String SECRET_KEY = "rFaaEspsFBWhicoEBDZuxPaYIREqvVJr"; |
||||
//百度大模型接口ERNIE-Bot 4.0地址
|
||||
public static final String URL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"; |
||||
//ERNIE-Bot-turbo地址
|
||||
//public static final String URL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant";
|
||||
static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().connectTimeout(60, TimeUnit.SECONDS).writeTimeout(60, TimeUnit.SECONDS).readTimeout(60, TimeUnit.SECONDS).build(); |
||||
|
||||
|
||||
@PostMapping("/sendMsg") |
||||
public CommonResult<Boolean> sendMsg(HttpServletRequest request, @org.springframework.web.bind.annotation.RequestBody LargeModeVO largeModeVO) throws IOException { |
||||
Map<String,String> map = new HashMap<>(); |
||||
LoginUser loginUser = SecurityFrameworkUtils.getLoginUser(); |
||||
String message = largeModeVO.getMessage(); |
||||
String dialogId = largeModeVO.getDialogId(); |
||||
String identity = largeModeVO.getIdentity(); |
||||
String redisContext = ""; |
||||
if(StringUtils.isEmpty(dialogId)){ |
||||
dialogId = IdUtil.simpleUUID(); |
||||
redisContext = "{\"role\":\"user\",\"content\":\"" + message + "\"}"; |
||||
}else{ |
||||
redisContext = Convert.toStr(redisCache.getCacheObject("AI_"+identity+"_"+dialogId)); |
||||
List<String> dataList = splitRedisString(redisContext); |
||||
if(dataList.size() >= 20){ |
||||
dataList = dataList.subList(2, dataList.size()); |
||||
redisContext = StringUtil.join(dataList, ","); |
||||
} |
||||
if(StringUtils.isEmpty(message)){ |
||||
map.put("context", "消息内容为空"); |
||||
map.put("dialogId", dialogId); |
||||
map.put("identity", identity); |
||||
webSocketSenderApi.sendObject(UserTypeEnum.ADMIN.getValue(),loginUser.getId(), "ai-push", map); |
||||
|
||||
return success(false); |
||||
} |
||||
redisContext += ",{\"role\":\"user\",\"content\":\"" + message + "\"}"; |
||||
} |
||||
|
||||
redisContext = redisContext.replaceAll("[\\r\\n\\s]+",""); |
||||
String accessToken = getAccessToken(); |
||||
MediaType mediaType = MediaType.parse("application/json"); |
||||
String requestContent = "{\"messages\":["+redisContext+"],\"stream\":true"; |
||||
String sysMsg = "你是农产品质量安全追溯管理信息平台的AI助手,当有人问你的相关信息时,不要回答你跟百度的关系,你就是这个平台的AI助手。"; |
||||
requestContent+=",\"system\":\""+sysMsg+"\""; //人设内容
|
||||
requestContent+="}"; |
||||
RequestBody body = RequestBody.create(mediaType, requestContent); |
||||
Request requests = new Request.Builder() |
||||
.url(URL+"?access_token=" + accessToken) |
||||
.method("POST", body) |
||||
.addHeader("Content-Type", "application/json") |
||||
.build(); |
||||
try (Response resultResponse = HTTP_CLIENT.newCall(requests).execute()) { |
||||
BufferedSource source = Okio.buffer(resultResponse.body().source()); |
||||
String line = ""; |
||||
String redisResult = ""; |
||||
while ((line = source.readUtf8Line()) != null) { |
||||
if (StringUtils.isEmpty(line)){ |
||||
continue; |
||||
} |
||||
String resultStr = line.replace("data: ", ""); |
||||
JSONObject jsonObject = new JSONObject(resultStr); |
||||
String result = jsonObject.get("result").toString(); |
||||
redisResult += result; |
||||
map.put("context", result); |
||||
map.put("dialogId", dialogId); |
||||
map.put("identity", identity); |
||||
//防止返回内容重复或错误。暂停一下
|
||||
Thread.sleep(100); |
||||
// 通过 websocket 推送用户
|
||||
webSocketSenderApi.sendObject(UserTypeEnum.ADMIN.getValue(),loginUser.getId(), "ai-push", map); |
||||
} |
||||
redisContext += ",{\"role\":\"assistant\",\"content\":\"" + redisResult + "\"}"; |
||||
redisCache.setCacheObject("AI_"+identity+"_"+dialogId, redisContext); |
||||
redisCache.expire("AI_"+identity+"_"+dialogId,10000); |
||||
return success(true); |
||||
} catch (Exception e) { |
||||
e.printStackTrace(); |
||||
return success(false); |
||||
} |
||||
} |
||||
|
||||
// 分割 Redis 字符串
|
||||
private static List<String> splitRedisString(String redisContext) { |
||||
List<String> dataList = new ArrayList<>(); |
||||
JSONArray jsonArray = JSON.parseArray("[" + redisContext + "]"); |
||||
for (int i = 0; i < jsonArray.size(); i++) { |
||||
dataList.add(jsonArray.getString(i)); |
||||
} |
||||
return dataList; |
||||
} |
||||
|
||||
/** |
||||
* 从用户的AK,SK生成鉴权签名(Access Token) |
||||
* |
||||
* @return 鉴权签名(Access Token) |
||||
* @throws IOException IO异常 |
||||
*/ |
||||
static String getAccessToken() throws IOException { |
||||
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); |
||||
RequestBody body = RequestBody.create(mediaType, "grant_type=client_credentials&client_id=" + API_KEY |
||||
+ "&client_secret=" + SECRET_KEY); |
||||
Request request = new Request.Builder() |
||||
.url("https://aip.baidubce.com/oauth/2.0/token") |
||||
.method("POST", body) |
||||
.addHeader("Content-Type", "application/x-www-form-urlencoded") |
||||
.build(); |
||||
Response response = HTTP_CLIENT.newCall(request).execute(); |
||||
return new JSONObject(response.body().string()).getString("access_token"); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,265 @@ |
||||
package com.cjy.traceability.module.traceability.controller.admin.ai.util; |
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired; |
||||
import org.springframework.data.redis.core.BoundSetOperations; |
||||
import org.springframework.data.redis.core.HashOperations; |
||||
import org.springframework.data.redis.core.RedisTemplate; |
||||
import org.springframework.data.redis.core.ValueOperations; |
||||
import org.springframework.stereotype.Component; |
||||
|
||||
import java.util.*; |
||||
import java.util.concurrent.TimeUnit; |
||||
|
||||
/** |
||||
* spring redis 工具类 |
||||
* |
||||
* @author ruoyi |
||||
**/ |
||||
@SuppressWarnings(value = { "unchecked", "rawtypes" }) |
||||
@Component |
||||
public class RedisCacheUtil |
||||
{ |
||||
@Autowired |
||||
public RedisTemplate redisTemplate; |
||||
|
||||
/** |
||||
* 缓存基本的对象,Integer、String、实体类等 |
||||
* |
||||
* @param key 缓存的键值 |
||||
* @param value 缓存的值 |
||||
*/ |
||||
public <T> void setCacheObject(final String key, final T value) |
||||
{ |
||||
redisTemplate.opsForValue().set(key, value); |
||||
} |
||||
|
||||
/** |
||||
* 缓存基本的对象,Integer、String、实体类等 |
||||
* |
||||
* @param key 缓存的键值 |
||||
* @param value 缓存的值 |
||||
* @param timeout 时间 |
||||
* @param timeUnit 时间颗粒度 |
||||
*/ |
||||
public <T> void setCacheObject(final String key, final T value, final Integer timeout, final TimeUnit timeUnit) |
||||
{ |
||||
redisTemplate.opsForValue().set(key, value, timeout, timeUnit); |
||||
} |
||||
|
||||
/** |
||||
* 设置有效时间 |
||||
* |
||||
* @param key Redis键 |
||||
* @param timeout 超时时间 |
||||
* @return true=设置成功;false=设置失败 |
||||
*/ |
||||
public boolean expire(final String key, final long timeout) |
||||
{ |
||||
return expire(key, timeout, TimeUnit.SECONDS); |
||||
} |
||||
|
||||
/** |
||||
* 设置有效时间 |
||||
* |
||||
* @param key Redis键 |
||||
* @param timeout 超时时间 |
||||
* @param unit 时间单位 |
||||
* @return true=设置成功;false=设置失败 |
||||
*/ |
||||
public boolean expire(final String key, final long timeout, final TimeUnit unit) |
||||
{ |
||||
return redisTemplate.expire(key, timeout, unit); |
||||
} |
||||
|
||||
/** |
||||
* 获取有效时间 |
||||
* |
||||
* @param key Redis键 |
||||
* @return 有效时间 |
||||
*/ |
||||
public long getExpire(final String key) |
||||
{ |
||||
return redisTemplate.getExpire(key); |
||||
} |
||||
|
||||
/** |
||||
* 判断 key是否存在 |
||||
* |
||||
* @param key 键 |
||||
* @return true 存在 false不存在 |
||||
*/ |
||||
public Boolean hasKey(String key) |
||||
{ |
||||
return redisTemplate.hasKey(key); |
||||
} |
||||
|
||||
/** |
||||
* 获得缓存的基本对象。 |
||||
* |
||||
* @param key 缓存键值 |
||||
* @return 缓存键值对应的数据 |
||||
*/ |
||||
public <T> T getCacheObject(final String key) |
||||
{ |
||||
ValueOperations<String, T> operation = redisTemplate.opsForValue(); |
||||
return operation.get(key); |
||||
} |
||||
|
||||
/** |
||||
* 删除单个对象 |
||||
* |
||||
* @param key |
||||
*/ |
||||
public boolean deleteObject(final String key) |
||||
{ |
||||
return redisTemplate.delete(key); |
||||
} |
||||
|
||||
/** |
||||
* 删除集合对象 |
||||
* |
||||
* @param collection 多个对象 |
||||
* @return |
||||
*/ |
||||
public boolean deleteObject(final Collection collection) |
||||
{ |
||||
return redisTemplate.delete(collection) > 0; |
||||
} |
||||
|
||||
/** |
||||
* 缓存List数据 |
||||
* |
||||
* @param key 缓存的键值 |
||||
* @param dataList 待缓存的List数据 |
||||
* @return 缓存的对象 |
||||
*/ |
||||
public <T> long setCacheList(final String key, final List<T> dataList) |
||||
{ |
||||
Long count = redisTemplate.opsForList().rightPushAll(key, dataList); |
||||
return count == null ? 0 : count; |
||||
} |
||||
|
||||
/** |
||||
* 获得缓存的list对象 |
||||
* |
||||
* @param key 缓存的键值 |
||||
* @return 缓存键值对应的数据 |
||||
*/ |
||||
public <T> List<T> getCacheList(final String key) |
||||
{ |
||||
return redisTemplate.opsForList().range(key, 0, -1); |
||||
} |
||||
|
||||
/** |
||||
* 缓存Set |
||||
* |
||||
* @param key 缓存键值 |
||||
* @param dataSet 缓存的数据 |
||||
* @return 缓存数据的对象 |
||||
*/ |
||||
public <T> BoundSetOperations<String, T> setCacheSet(final String key, final Set<T> dataSet) |
||||
{ |
||||
BoundSetOperations<String, T> setOperation = redisTemplate.boundSetOps(key); |
||||
Iterator<T> it = dataSet.iterator(); |
||||
while (it.hasNext()) |
||||
{ |
||||
setOperation.add(it.next()); |
||||
} |
||||
return setOperation; |
||||
} |
||||
|
||||
/** |
||||
* 获得缓存的set |
||||
* |
||||
* @param key |
||||
* @return |
||||
*/ |
||||
public <T> Set<T> getCacheSet(final String key) |
||||
{ |
||||
return redisTemplate.opsForSet().members(key); |
||||
} |
||||
|
||||
/** |
||||
* 缓存Map |
||||
* |
||||
* @param key |
||||
* @param dataMap |
||||
*/ |
||||
public <T> void setCacheMap(final String key, final Map<String, T> dataMap) |
||||
{ |
||||
if (dataMap != null) { |
||||
redisTemplate.opsForHash().putAll(key, dataMap); |
||||
} |
||||
} |
||||
|
||||
/** |
||||
* 获得缓存的Map |
||||
* |
||||
* @param key |
||||
* @return |
||||
*/ |
||||
public <T> Map<String, T> getCacheMap(final String key) |
||||
{ |
||||
return redisTemplate.opsForHash().entries(key); |
||||
} |
||||
|
||||
/** |
||||
* 往Hash中存入数据 |
||||
* |
||||
* @param key Redis键 |
||||
* @param hKey Hash键 |
||||
* @param value 值 |
||||
*/ |
||||
public <T> void setCacheMapValue(final String key, final String hKey, final T value) |
||||
{ |
||||
redisTemplate.opsForHash().put(key, hKey, value); |
||||
} |
||||
|
||||
/** |
||||
* 获取Hash中的数据 |
||||
* |
||||
* @param key Redis键 |
||||
* @param hKey Hash键 |
||||
* @return Hash中的对象 |
||||
*/ |
||||
public <T> T getCacheMapValue(final String key, final String hKey) |
||||
{ |
||||
HashOperations<String, String, T> opsForHash = redisTemplate.opsForHash(); |
||||
return opsForHash.get(key, hKey); |
||||
} |
||||
|
||||
/** |
||||
* 获取多个Hash中的数据 |
||||
* |
||||
* @param key Redis键 |
||||
* @param hKeys Hash键集合 |
||||
* @return Hash对象集合 |
||||
*/ |
||||
public <T> List<T> getMultiCacheMapValue(final String key, final Collection<Object> hKeys) |
||||
{ |
||||
return redisTemplate.opsForHash().multiGet(key, hKeys); |
||||
} |
||||
|
||||
/** |
||||
* 删除Hash中的某条数据 |
||||
* |
||||
* @param key Redis键 |
||||
* @param hKey Hash键 |
||||
* @return 是否成功 |
||||
*/ |
||||
public boolean deleteCacheMapValue(final String key, final String hKey) |
||||
{ |
||||
return redisTemplate.opsForHash().delete(key, hKey) > 0; |
||||
} |
||||
|
||||
/** |
||||
* 获得缓存的基本对象列表 |
||||
* |
||||
* @param pattern 字符串前缀 |
||||
* @return 对象列表 |
||||
*/ |
||||
public Collection<String> keys(final String pattern) |
||||
{ |
||||
return redisTemplate.keys(pattern); |
||||
} |
||||
} |
@ -0,0 +1,21 @@ |
||||
package com.cjy.traceability.module.traceability.controller.admin.ai.vo; |
||||
|
||||
import lombok.Data; |
||||
import lombok.ToString; |
||||
|
||||
/** |
||||
* @ClassName: LargeModeVO |
||||
* @Author: zc |
||||
* @Date: 2024/4/22 17:22 |
||||
* @Description: TODO |
||||
*/ |
||||
@Data |
||||
@ToString(callSuper = true) |
||||
public class LargeModeVO { |
||||
|
||||
private String message; |
||||
|
||||
private String dialogId; |
||||
|
||||
private String identity; |
||||
} |
Loading…
Reference in new issue