feat: basic LLM functionality implemented

2025-05-05 01:14:39 +08:00
parent bef6849e45
commit 077166e1a6
21 changed files with 701 additions and 44 deletions

View File

@@ -5,4 +5,10 @@ public interface BaseConstant {
Integer DEFAULT_VERSION = 1;
String TEXT = "text";
String FINISH_REASON = "finishReason";
String USAGE = "usage";
}

View File

@@ -79,6 +79,7 @@ public class RunningContext {
return parser.parseExpression(key).getValue(context);
}
// try {
// // Parse the numeric part of the key and convert it to Long
// String[] parts = key.split("\\.");
@@ -93,8 +94,10 @@ public class RunningContext {
// log.error("数字类型获取动态参数失败: {}", key);
// }
ExpressionParser parser = new SpelExpressionParser();
StandardEvaluationContext context = new StandardEvaluationContext(this);
return parser.parseExpression(key).getValue(context);
StandardEvaluationContext context = new StandardEvaluationContext();
Map<String, Object> variables = new HashMap<>(this.nodeRunningContext);
context.setVariables(variables);
return parser.parseExpression(convertDotToSquareBrackets(key)).getValue(context);
}
@@ -134,5 +137,8 @@ public class RunningContext {
.build();
}
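// Rewrites a dotted key into SpEL bracket access on the registered variables, e.g. "node_5.query" -> "#node_5['query']".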
public String convertDotToSquareBrackets(String key) {
return key.replaceAll("(\\w+)\\.(\\w+)", "#$1['$2']");
}
}
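
For context, a minimal standalone sketch of the lookup the reworked getValue performs, assuming Spring's SpEL API; the class name SpelLookupSketch and the sample values are illustrative and not part of this commit. The per-node map is registered as SpEL variables, and the dotted key is rewritten to bracket access (via convertDotToSquareBrackets) before evaluation.

import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import java.util.HashMap;
import java.util.Map;

class SpelLookupSketch {
    public static void main(String[] args) {
        // Stand-in for nodeRunningContext: one entry per node id, holding that node's outputs.
        Map<String, Object> nodeRunningContext = new HashMap<>();
        nodeRunningContext.put("node_5", Map.of("query", "hello"));

        ExpressionParser parser = new SpelExpressionParser();
        StandardEvaluationContext context = new StandardEvaluationContext();
        context.setVariables(new HashMap<>(nodeRunningContext));

        // convertDotToSquareBrackets("node_5.query") yields "#node_5['query']"
        Object value = parser.parseExpression("#node_5['query']").getValue(context);
        System.out.println(value); // prints "hello"
    }
}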

View File

@@ -1,10 +1,28 @@
package com.metis.domain.entity.config.node;
import com.metis.domain.entity.base.NodeConfig;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.List;
@Data
@EqualsAndHashCode(callSuper = true)
public class EndNodeConfig extends NodeConfig {
@Valid
private List<Variable> variables;
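// Each Variable maps an output field name (variable) to a running-context key (variableKey), e.g. "node_5.query".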
@Data
public static class Variable {
@NotBlank(message = "参数字段不能为空")
private String variable;
@NotBlank(message = "参数key不能为空")
private String variableKey;
}
}

View File

@@ -95,7 +95,7 @@ public class AppFlowEngineRunnerServiceImpl implements AppFlowEngineRunnerServic
return RunnerResult.builder()
.result(endRunningContext)
.context(sysContext)
// .context(sysContext)
.build();
}

View File

@@ -0,0 +1,49 @@
package com.metis.llm.domain;
import cn.hutool.core.util.ObjectUtil;
import dev.langchain4j.model.openai.OpenAiTokenUsage;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Data;
@Data
public class Usage {
private Integer inputTokenCount;
private Integer outputTokenCount;
private Integer totalTokenCount;
private InputTokensDetails inputTokensDetails;
private OutputTokensDetails outputTokensDetails;
@Data
private static class InputTokensDetails {
private Integer cachedTokens;
}
@Data
private static class OutputTokensDetails {
private Integer reasoningTokens;
}
public static Usage buildTokenUsage(TokenUsage tokenUsage) {
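// Copy the generic token counts; OpenAI-specific details (cached and reasoning tokens) are filled in below when present.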
Usage usage = new Usage();
usage.setInputTokenCount(tokenUsage.inputTokenCount());
usage.setOutputTokenCount(tokenUsage.outputTokenCount());
usage.setTotalTokenCount(tokenUsage.totalTokenCount());
if (tokenUsage instanceof OpenAiTokenUsage openAiTokenUsage) {
if (ObjectUtil.isNotNull(openAiTokenUsage.inputTokensDetails())) {
InputTokensDetails inputTokensDetails = new InputTokensDetails();
Integer cachedTokens = openAiTokenUsage.inputTokensDetails().cachedTokens();
inputTokensDetails.setCachedTokens(cachedTokens);
usage.setInputTokensDetails(inputTokensDetails);
}
if (ObjectUtil.isNotNull(openAiTokenUsage.outputTokensDetails())) {
OutputTokensDetails outputTokensDetails = new OutputTokensDetails();
Integer reasoningTokens = openAiTokenUsage.outputTokensDetails().reasoningTokens();
outputTokensDetails.setReasoningTokens(reasoningTokens);
usage.setOutputTokensDetails(outputTokensDetails);
}
}
return usage;
}
}

View File

@@ -2,7 +2,7 @@ package com.metis.llm.engine;
import com.metis.domain.entity.base.Model;
import com.metis.enums.ModelTypeEnum;
import com.metis.llm.ModelEngine;
import com.metis.llm.service.ModelEngine;
import com.metis.llm.domain.CompletionParams;
import com.metis.llm.domain.LLMChatModeConfig;
import com.metis.llm.domain.LLMEmbeddingModelConfig;

View File

@@ -2,8 +2,8 @@ package com.metis.llm.factory;
import cn.hutool.core.lang.Assert;
import com.metis.enums.ModelTypeEnum;
import com.metis.llm.CustomModelEngine;
import com.metis.llm.ModelEngine;
import com.metis.llm.service.CustomModelEngine;
import com.metis.llm.service.ModelEngine;
import com.metis.llm.domain.config.BaseModelConfig;
import java.util.ArrayList;

View File

@@ -3,8 +3,8 @@ package com.metis.llm.factory;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjectUtil;
import com.metis.enums.ModelTypeEnum;
import com.metis.llm.CustomModelEngine;
import com.metis.llm.ModelEngine;
import com.metis.llm.service.CustomModelEngine;
import com.metis.llm.service.ModelEngine;
import com.metis.llm.domain.config.BaseModelConfig;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;

View File

@@ -0,0 +1,12 @@
package com.metis.llm.service;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.service.Result;
import java.util.List;
public interface ChatBoot {
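// Returns the model reply wrapped in langchain4j's Result, which also carries token usage and the finish reason.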
Result<String> chat(List<ChatMessage> chatMessageList);
}

View File

@@ -1,4 +1,4 @@
package com.metis.llm;
package com.metis.llm.service;
import com.metis.enums.ModelTypeEnum;
import com.metis.llm.domain.config.BaseModelConfig;

View File

@@ -0,0 +1,528 @@
package com.metis.llm.service;
import dev.langchain4j.Internal;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.*;
import dev.langchain4j.service.memory.ChatMemoryAccess;
import dev.langchain4j.service.memory.ChatMemoryService;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolServiceContext;
import dev.langchain4j.service.tool.ToolServiceResult;
import dev.langchain4j.spi.services.AiServicesFactory;
import dev.langchain4j.spi.services.TokenStreamAdapter;
import java.io.InputStream;
import java.lang.reflect.*;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
import static dev.langchain4j.service.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.service.TypeUtils.typeHasRawClass;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
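// Adapted from langchain4j's default AiServices implementation so the build()/invoke flow (and the
// shape of the Result handed back to flow nodes) can be customized without patching the library.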
@Internal
public class FlowAiServices<T> extends AiServices<T> {
private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();
private final Collection<TokenStreamAdapter> tokenStreamAdapters = loadFactories(TokenStreamAdapter.class);
FlowAiServices(AiServiceContext context) {
super(context);
}
/**
* Begins the construction of an AI Service.
*
* @param aiService The class of the interface to be implemented.
* @return builder
*/
public static <T> AiServices<T> builder(Class<T> aiService) {
AiServiceContext context = new AiServiceContext(aiService);
for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) {
return factory.create(context);
}
return new FlowAiServices<>(context);
}
static void validateParameters(Method method) {
Parameter[] parameters = method.getParameters();
if (parameters == null || parameters.length < 2) {
return;
}
for (Parameter parameter : parameters) {
V v = parameter.getAnnotation(V.class);
dev.langchain4j.service.UserMessage userMessage =
parameter.getAnnotation(dev.langchain4j.service.UserMessage.class);
MemoryId memoryId = parameter.getAnnotation(MemoryId.class);
UserName userName = parameter.getAnnotation(UserName.class);
if (v == null && userMessage == null && memoryId == null && userName == null) {
throw illegalConfiguration(
"Parameter '%s' of method '%s' should be annotated with @V or @UserMessage "
+ "or @UserName or @MemoryId",
parameter.getName(), method.getName());
}
}
}
public T build() {
performBasicValidation();
if (!context.hasChatMemory() && ChatMemoryAccess.class.isAssignableFrom(context.aiServiceClass)) {
throw illegalConfiguration(
"In order to have a service implementing ChatMemoryAccess, please configure the ChatMemoryProvider on the '%s'.",
context.aiServiceClass.getName());
}
for (Method method : context.aiServiceClass.getMethods()) {
if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {
throw illegalConfiguration(
"The @Moderate annotation is present, but the moderationModel is not set up. "
+ "Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
}
Class<?> returnType = method.getReturnType();
if (returnType == void.class) {
throw illegalConfiguration("'%s' is not a supported return type of an AI Service method", returnType.getName());
}
if (returnType == Result.class || returnType == List.class || returnType == Set.class) {
TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
}
if (!context.hasChatMemory()) {
for (Parameter parameter : method.getParameters()) {
if (parameter.isAnnotationPresent(MemoryId.class)) {
throw illegalConfiguration(
"In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.",
context.aiServiceClass.getName());
}
}
}
}
Object proxyInstance = Proxy.newProxyInstance(
context.aiServiceClass.getClassLoader(),
new Class<?>[]{context.aiServiceClass},
new InvocationHandler() {
private final ExecutorService executor = Executors.newCachedThreadPool();
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
if (method.getDeclaringClass() == Object.class) {
// methods like equals(), hashCode() and toString() should not be handled by this proxy
return method.invoke(this, args);
}
if (method.getDeclaringClass() == ChatMemoryAccess.class) {
return switch (method.getName()) {
case "getChatMemory" -> context.chatMemoryService.getChatMemory(args[0]);
case "evictChatMemory" -> context.chatMemoryService.evictChatMemory(args[0]) != null;
default -> throw new UnsupportedOperationException(
"Unknown method on ChatMemoryAccess class : " + method.getName());
};
}
validateParameters(method);
final Object memoryId = findMemoryId(method, args).orElse(ChatMemoryService.DEFAULT);
final ChatMemory chatMemory = context.hasChatMemory()
? context.chatMemoryService.getOrCreateChatMemory(memoryId)
: null;
Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
UserMessage userMessage = prepareUserMessage(method, args);
AugmentationResult augmentationResult = null;
if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemoryMessages = chatMemory != null ? chatMemory.messages() : null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemoryMessages);
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
userMessage = (UserMessage) augmentationResult.chatMessage();
}
Type returnType = method.getGenericReturnType();
boolean streaming = returnType == TokenStream.class || canAdaptTokenStreamTo(returnType);
boolean supportsJsonSchema = supportsJsonSchema();
Optional<JsonSchema> jsonSchema = Optional.empty();
if (supportsJsonSchema && !streaming) {
jsonSchema = serviceOutputParser.jsonSchema(returnType);
}
if ((!supportsJsonSchema || jsonSchema.isEmpty()) && !streaming) {
userMessage = appendOutputFormatInstructions(returnType, userMessage);
}
List<ChatMessage> messages;
if (chatMemory != null) {
systemMessage.ifPresent(chatMemory::add);
chatMemory.add(userMessage);
messages = chatMemory.messages();
} else {
messages = new ArrayList<>();
systemMessage.ifPresent(messages::add);
messages.add(userMessage);
}
Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);
ToolServiceContext toolServiceContext =
context.toolService.createContext(memoryId, userMessage);
if (streaming) {
TokenStream tokenStream = new AiServiceTokenStream(AiServiceTokenStreamParameters.builder()
.messages(messages)
.toolSpecifications(toolServiceContext.toolSpecifications())
.toolExecutors(toolServiceContext.toolExecutors())
.retrievedContents(
augmentationResult != null ? augmentationResult.contents() : null)
.context(context)
.memoryId(memoryId)
.build());
// TODO moderation
if (returnType == TokenStream.class) {
return tokenStream;
} else {
return adapt(tokenStream, returnType);
}
}
ResponseFormat responseFormat = null;
if (supportsJsonSchema && jsonSchema.isPresent()) {
responseFormat = ResponseFormat.builder()
.type(JSON)
.jsonSchema(jsonSchema.get())
.build();
}
ChatRequestParameters parameters = ChatRequestParameters.builder()
.toolSpecifications(toolServiceContext.toolSpecifications())
.responseFormat(responseFormat)
.build();
ChatRequest chatRequest = ChatRequest.builder()
.messages(messages)
.parameters(parameters)
.build();
ChatResponse chatResponse = context.chatModel.chat(chatRequest);
verifyModerationIfNeeded(moderationFuture);
ToolServiceResult toolServiceResult = context.toolService.executeInferenceAndToolsLoop(
chatResponse,
parameters,
messages,
context.chatModel,
chatMemory,
memoryId,
toolServiceContext.toolExecutors());
chatResponse = toolServiceResult.chatResponse();
Object parsedResponse = serviceOutputParser.parse(chatResponse, returnType);
if (typeHasRawClass(returnType, Result.class)) {
return Result.builder()
.content(parsedResponse)
.tokenUsage(chatResponse.tokenUsage())
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(chatResponse.finishReason())
.toolExecutions(toolServiceResult.toolExecutions())
.build();
} else {
return parsedResponse;
}
}
private boolean canAdaptTokenStreamTo(Type returnType) {
for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
return true;
}
}
return false;
}
private Object adapt(TokenStream tokenStream, Type returnType) {
for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
return tokenStreamAdapter.adapt(tokenStream);
}
}
throw new IllegalStateException("Can't find suitable TokenStreamAdapter");
}
private boolean supportsJsonSchema() {
return context.chatModel != null
&& context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
}
private UserMessage appendOutputFormatInstructions(Type returnType, UserMessage userMessage) {
String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
String text = userMessage.singleText() + outputFormatInstructions;
if (isNotNullOrBlank(userMessage.name())) {
userMessage = UserMessage.from(userMessage.name(), text);
} else {
userMessage = UserMessage.from(text);
}
return userMessage;
}
private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
if (method.isAnnotationPresent(Moderate.class)) {
return executor.submit(() -> {
List<ChatMessage> messagesToModerate = removeToolMessages(messages);
return context.moderationModel
.moderate(messagesToModerate)
.content();
});
}
return null;
}
});
return (T) proxyInstance;
}
private Optional<SystemMessage> prepareSystemMessage(Object memoryId, Method method, Object[] args) {
return findSystemMessageTemplate(memoryId, method)
.map(systemMessageTemplate -> PromptTemplate.from(systemMessageTemplate)
.apply(findTemplateVariables(systemMessageTemplate, method, args))
.toSystemMessage());
}
private Optional<String> findSystemMessageTemplate(Object memoryId, Method method) {
dev.langchain4j.service.SystemMessage annotation =
method.getAnnotation(dev.langchain4j.service.SystemMessage.class);
if (annotation != null) {
return Optional.of(getTemplate(
method, "System", annotation.fromResource(), annotation.value(), annotation.delimiter()));
}
return context.systemMessageProvider.apply(memoryId);
}
private static Map<String, Object> findTemplateVariables(String template, Method method, Object[] args) {
Parameter[] parameters = method.getParameters();
Map<String, Object> variables = new HashMap<>();
for (int i = 0; i < parameters.length; i++) {
String variableName = getVariableName(parameters[i]);
Object variableValue = args[i];
variables.put(variableName, variableValue);
}
if (template.contains("{{it}}") && !variables.containsKey("it")) {
String itValue = getValueOfVariableIt(parameters, args);
variables.put("it", itValue);
}
return variables;
}
private static String getVariableName(Parameter parameter) {
V annotation = parameter.getAnnotation(V.class);
if (annotation != null) {
return annotation.value();
} else {
return parameter.getName();
}
}
private static String getValueOfVariableIt(Parameter[] parameters, Object[] args) {
if (parameters.length == 1) {
Parameter parameter = parameters[0];
if (!parameter.isAnnotationPresent(MemoryId.class)
&& !parameter.isAnnotationPresent(dev.langchain4j.service.UserMessage.class)
&& !parameter.isAnnotationPresent(UserName.class)
&& (!parameter.isAnnotationPresent(V.class) || isAnnotatedWithIt(parameter))) {
return toString(args[0]);
}
}
for (int i = 0; i < parameters.length; i++) {
if (isAnnotatedWithIt(parameters[i])) {
return toString(args[i]);
}
}
throw illegalConfiguration("Error: cannot find the value of the prompt template variable \"{{it}}\".");
}
private static boolean isAnnotatedWithIt(Parameter parameter) {
V annotation = parameter.getAnnotation(V.class);
return annotation != null && "it".equals(annotation.value());
}
private static UserMessage prepareUserMessage(Method method, Object[] args) {
String template = getUserMessageTemplate(method, args);
Map<String, Object> variables = findTemplateVariables(template, method, args);
Prompt prompt = PromptTemplate.from(template).apply(variables);
Optional<String> maybeUserName = findUserName(method.getParameters(), args);
return maybeUserName
.map(userName -> UserMessage.from(userName, prompt.text()))
.orElseGet(prompt::toUserMessage);
}
private static String getUserMessageTemplate(Method method, Object[] args) {
Optional<String> templateFromMethodAnnotation = findUserMessageTemplateFromMethodAnnotation(method);
Optional<String> templateFromParameterAnnotation =
findUserMessageTemplateFromAnnotatedParameter(method.getParameters(), args);
if (templateFromMethodAnnotation.isPresent() && templateFromParameterAnnotation.isPresent()) {
throw illegalConfiguration(
"Error: The method '%s' has multiple @UserMessage annotations. Please use only one.",
method.getName());
}
if (templateFromMethodAnnotation.isPresent()) {
return templateFromMethodAnnotation.get();
}
if (templateFromParameterAnnotation.isPresent()) {
return templateFromParameterAnnotation.get();
}
Optional<String> templateFromTheOnlyArgument =
findUserMessageTemplateFromTheOnlyArgument(method.getParameters(), args);
if (templateFromTheOnlyArgument.isPresent()) {
return templateFromTheOnlyArgument.get();
}
throw illegalConfiguration("Error: The method '%s' does not have a user message defined.", method.getName());
}
private static Optional<String> findUserMessageTemplateFromMethodAnnotation(Method method) {
return Optional.ofNullable(method.getAnnotation(dev.langchain4j.service.UserMessage.class))
.map(a -> getTemplate(method, "User", a.fromResource(), a.value(), a.delimiter()));
}
private static Optional<String> findUserMessageTemplateFromAnnotatedParameter(
Parameter[] parameters, Object[] args) {
for (int i = 0; i < parameters.length; i++) {
if (parameters[i].isAnnotationPresent(dev.langchain4j.service.UserMessage.class)) {
return Optional.of(toString(args[i]));
}
}
return Optional.empty();
}
private static Optional<String> findUserMessageTemplateFromTheOnlyArgument(Parameter[] parameters, Object[] args) {
if (parameters != null && parameters.length == 1 && parameters[0].getAnnotations().length == 0) {
return Optional.of(toString(args[0]));
}
return Optional.empty();
}
private static Optional<String> findUserName(Parameter[] parameters, Object[] args) {
for (int i = 0; i < parameters.length; i++) {
if (parameters[i].isAnnotationPresent(UserName.class)) {
return Optional.of(args[i].toString());
}
}
return Optional.empty();
}
private static String getTemplate(Method method, String type, String resource, String[] value, String delimiter) {
String messageTemplate;
if (!resource.trim().isEmpty()) {
messageTemplate = getResourceText(method.getDeclaringClass(), resource);
if (messageTemplate == null) {
throw illegalConfiguration("@%sMessage's resource '%s' not found", type, resource);
}
} else {
messageTemplate = String.join(delimiter, value);
}
if (messageTemplate.trim().isEmpty()) {
throw illegalConfiguration("@%sMessage's template cannot be empty", type);
}
return messageTemplate;
}
private static String getResourceText(Class<?> clazz, String resource) {
InputStream inputStream = clazz.getResourceAsStream(resource);
if (inputStream == null) {
inputStream = clazz.getResourceAsStream("/" + resource);
}
return getText(inputStream);
}
private static String getText(InputStream inputStream) {
if (inputStream == null) {
return null;
}
try (Scanner scanner = new Scanner(inputStream);
Scanner s = scanner.useDelimiter("\\A")) {
return s.hasNext() ? s.next() : "";
}
}
private static Optional<Object> findMemoryId(Method method, Object[] args) {
Parameter[] parameters = method.getParameters();
for (int i = 0; i < parameters.length; i++) {
if (parameters[i].isAnnotationPresent(MemoryId.class)) {
Object memoryId = args[i];
if (memoryId == null) {
throw illegalArgument(
"The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null",
parameters[i].getName(), method.getName());
}
return Optional.of(memoryId);
}
}
return Optional.empty();
}
private static String toString(Object arg) {
if (arg.getClass().isArray()) {
return arrayToString(arg);
} else if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) {
return StructuredPromptProcessor.toPrompt(arg).text();
} else {
return arg.toString();
}
}
private static String arrayToString(Object arg) {
StringBuilder sb = new StringBuilder("[");
int length = Array.getLength(arg);
for (int i = 0; i < length; i++) {
sb.append(toString(Array.get(arg, i)));
if (i < length - 1) {
sb.append(", ");
}
}
sb.append("]");
return sb.toString();
}
}

View File

@@ -1,4 +1,4 @@
package com.metis.llm;
package com.metis.llm.service;
import com.metis.domain.entity.base.Model;

View File

@@ -1,4 +1,4 @@
package com.metis.llm;
package com.metis.llm.service;
import com.metis.llm.domain.LLMChatModeConfig;
import dev.langchain4j.model.chat.ChatModel;

View File

@@ -1,4 +1,4 @@
package com.metis.llm;
package com.metis.llm.service;
import com.metis.domain.entity.base.Model;

View File

@@ -1,12 +1,12 @@
package com.metis.llm.impl;
package com.metis.llm.service.impl;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjectUtil;
import com.metis.domain.entity.base.Model;
import com.metis.enums.ModelTypeEnum;
import com.metis.llm.ModelEngine;
import com.metis.llm.ModelEngineService;
import com.metis.llm.ModelService;
import com.metis.llm.service.ModelEngine;
import com.metis.llm.service.ModelEngineService;
import com.metis.llm.service.ModelService;
import com.metis.llm.domain.LLMChatModeConfig;
import com.metis.llm.domain.LLMEmbeddingModelConfig;
import com.metis.llm.factory.ModelEngineFactory;

View File

@@ -1,11 +1,11 @@
package com.metis.llm.impl;
package com.metis.llm.service.impl;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjectUtil;
import com.metis.convert.ModelPlatformConvert;
import com.metis.domain.entity.ModelPlatform;
import com.metis.domain.entity.base.Model;
import com.metis.llm.ModelService;
import com.metis.llm.service.ModelService;
import com.metis.service.ModelPlatformService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

View File

@@ -2,12 +2,9 @@ package com.metis.runner;
import com.alibaba.fastjson2.JSONObject;
import com.metis.domain.context.SysContext;
import lombok.Builder;
import lombok.Data;
import java.util.Map;
/**
* Run result
*
@@ -26,7 +23,7 @@ public class RunnerResult {
/**
* Context
*/
private SysContext context;
// private SysContext context;
}

View File

@@ -1,6 +1,7 @@
package com.metis.runner.impl;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.fastjson2.JSONObject;
import com.metis.domain.context.RunningContext;
import com.metis.domain.context.RunningResult;
@@ -21,7 +22,15 @@ public class EndNodeRunner implements NodeRunner<EndNodeConfig> {
@Override
public RunningResult run(RunningContext context, Node node, List<Edge> edges) {
JSONObject contextNodeValue = new JSONObject();
contextNodeValue.put("userId", context.getSys().getAppId());
EndNodeConfig config = node.getConfig();
List<EndNodeConfig.Variable> variables = config.getVariables();
if (CollUtil.isEmpty(variables)) {
return RunningResult.buildResult();
}
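// Resolve each configured key (e.g. "node_5.query") from the running context and expose it under its output name.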
for (EndNodeConfig.Variable variable : variables) {
Object value = context.getValue(variable.getVariableKey());
contextNodeValue.put(variable.getVariable(), value);
}
return RunningResult.buildResult(contextNodeValue);
}

View File

@@ -39,8 +39,8 @@ public class StartNodeRunner implements NodeRunner<StartNodeConfig> {
JSONObject custom = context.getCustom();
JSONObject contextNodeValue = new JSONObject();
// Fetch the system context and put it into the start node's run result for use by later calls
JSONObject sysContext = getSysContext(context);
contextNodeValue.putAll(sysContext);
// JSONObject sysContext = getSysContext(context);
// contextNodeValue.putAll(sysContext);
for (NodeVariable variable : variables) {
Object value = variable.getValue(custom);

View File

@@ -1,4 +1,4 @@
package com.metis.runner.impl;
package com.metis.runner.impl.llm;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
@@ -11,8 +11,11 @@ import com.metis.domain.entity.base.Node;
import com.metis.domain.entity.config.node.llm.LLMNodeConfig;
import com.metis.domain.entity.config.node.llm.PromptTemplate;
import com.metis.enums.NodeType;
import com.metis.llm.ModelEngineService;
import com.metis.llm.domain.LLMChatModeConfig;
import com.metis.llm.domain.Usage;
import com.metis.llm.service.ChatBoot;
import com.metis.llm.service.FlowAiServices;
import com.metis.llm.service.ModelEngineService;
import com.metis.runner.NodeRunner;
import com.metis.template.domain.RenderContext;
import com.metis.template.utils.VelocityUtil;
@@ -21,7 +24,7 @@ import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.service.Result;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@@ -45,12 +48,20 @@ public class LLMNodeRunner implements NodeRunner<LLMNodeConfig> {
ChatModel chatModel = modelEngineService.getChatLanguageModel(model);
List<ChatMessage> chatMessageList = buildChatMessage(context, config);
ChatResponse chat = chatModel.chat(chatMessageList);
String string = chat.toString();
log.info("LLM 输出结果: {}", string);
// TODO: use FlowAiServices to extend how the build method shapes the returned data; this can be developed as an extension in phase 2 or 3
ChatBoot chatBoot = FlowAiServices.builder(ChatBoot.class)
.chatModel(chatModel)
.build();
Result<String> chatResult = chatBoot.chat(chatMessageList);
String text = chatResult.content();
log.info("LLM 输出结果: {}", text);
JSONObject result = new JSONObject();
result.put(BaseConstant.TEXT, string);
result.put(BaseConstant.TEXT, text);
result.put(BaseConstant.USAGE, Usage.buildTokenUsage(chatResult.tokenUsage()));
result.put(BaseConstant.FINISH_REASON, chatResult.finishReason());
return RunningResult.buildResult(result);
}

View File

@@ -1,4 +1,3 @@
{
"appId": 1919041086810968064,
"name": "llm运行测试",
@@ -6,7 +5,7 @@
"graph": {
"nodes": [
{
"id": "5",
"id": "node_5",
"type": "start",
"initialized": false,
"position": {
@@ -49,7 +48,7 @@
"height": 40
},
{
"id": "700",
"id": "node_700",
"type": "llm",
"initialized": false,
"position": {
@@ -61,7 +60,7 @@
"icon": "",
"toolbarPosition": "right",
"config": {
"context": "5.background",
"context": "node_5.background",
"retryConfig": {
"enable": true,
"maxRetries": 3,
@@ -75,7 +74,7 @@
},
{
"role": "user",
"text": "请你解释一下上述问题${5.query}",
"text": "请你解释一下上述问题${node_5.query}",
"id": "2"
}
],
@@ -112,7 +111,7 @@
"height": 40
},
{
"id": "802",
"id": "node_802",
"type": "end",
"initialized": false,
"position": {
@@ -123,7 +122,30 @@
"label": "结束",
"icon": "",
"toolbarPosition": "right",
"config": {},
"config": {
"variables": [
{
"variable": "query",
"variableKey": "node_5.query"
},
{
"variable": "background",
"variableKey": "node_5.background"
},
{
"variable": "usage",
"variableKey": "node_700.usage"
},
{
"variable": "finishReason",
"variableKey": "node_700.finishReason"
},
{
"variable": "text",
"variableKey": "node_700.text"
}
]
},
"handles": [
{
"id": "13",
@@ -142,8 +164,8 @@
{
"id": "vueflow__edge-551-70057",
"type": "default",
"source": "5",
"target": "700",
"source": "node_5",
"target": "node_700",
"sourceHandle": "51",
"targetHandle": "57",
"data": {},
@@ -159,8 +181,8 @@
{
"id": "802",
"type": "default",
"source": "700",
"target": "802",
"source": "node_700",
"target": "node_802",
"sourceHandle": "35",
"targetHandle": "13",
"data": {},
@@ -185,5 +207,4 @@
"zoom": 3.5988075528449266
}
}
}