diff --git a/metis-starter/src/main/java/com/metis/constant/BaseConstant.java b/metis-starter/src/main/java/com/metis/constant/BaseConstant.java index a6781f3..71a75d9 100644 --- a/metis-starter/src/main/java/com/metis/constant/BaseConstant.java +++ b/metis-starter/src/main/java/com/metis/constant/BaseConstant.java @@ -5,4 +5,10 @@ public interface BaseConstant { Integer DEFAULT_VERSION = 1; String TEXT = "text"; + + String FINISH_REASON = "finishReason"; + + String USAGE = "usage"; + + } diff --git a/metis-starter/src/main/java/com/metis/domain/context/RunningContext.java b/metis-starter/src/main/java/com/metis/domain/context/RunningContext.java index 117fd4a..6eda5f6 100644 --- a/metis-starter/src/main/java/com/metis/domain/context/RunningContext.java +++ b/metis-starter/src/main/java/com/metis/domain/context/RunningContext.java @@ -79,6 +79,7 @@ public class RunningContext { return parser.parseExpression(key).getValue(context); } + // try { // // 解析 key 中的数字部分并转换为 Long 类型 // String[] parts = key.split("\\."); @@ -93,8 +94,10 @@ public class RunningContext { // log.error("数字类型获取动态参数失败: {}", key); // } ExpressionParser parser = new SpelExpressionParser(); - StandardEvaluationContext context = new StandardEvaluationContext(this); - return parser.parseExpression(key).getValue(context); + StandardEvaluationContext context = new StandardEvaluationContext(); + Map variables = new HashMap<>(this.nodeRunningContext); + context.setVariables(variables); + return parser.parseExpression(convertDotToSquareBrackets(key)).getValue(context); } @@ -134,5 +137,8 @@ public class RunningContext { .build(); } + public String convertDotToSquareBrackets(String key) { + return key.replaceAll("(\\w+)\\.(\\w+)", "#$1['$2']"); + } } diff --git a/metis-starter/src/main/java/com/metis/domain/entity/config/node/EndNodeConfig.java b/metis-starter/src/main/java/com/metis/domain/entity/config/node/EndNodeConfig.java index 202e601..d256d95 100644 --- a/metis-starter/src/main/java/com/metis/domain/entity/config/node/EndNodeConfig.java +++ b/metis-starter/src/main/java/com/metis/domain/entity/config/node/EndNodeConfig.java @@ -1,10 +1,28 @@ package com.metis.domain.entity.config.node; import com.metis.domain.entity.base.NodeConfig; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotBlank; import lombok.Data; import lombok.EqualsAndHashCode; +import java.util.List; + @Data @EqualsAndHashCode(callSuper = true) public class EndNodeConfig extends NodeConfig { + + + @Valid + private List variables; + + + @Data + public static class Variable { + @NotBlank(message = "参数字段不能为空") + private String variable; + @NotBlank(message = "参数key不能为空") + private String variableKey; + + } } diff --git a/metis-starter/src/main/java/com/metis/engine/impl/AppFlowEngineRunnerServiceImpl.java b/metis-starter/src/main/java/com/metis/engine/impl/AppFlowEngineRunnerServiceImpl.java index 05e7aab..8b6b503 100644 --- a/metis-starter/src/main/java/com/metis/engine/impl/AppFlowEngineRunnerServiceImpl.java +++ b/metis-starter/src/main/java/com/metis/engine/impl/AppFlowEngineRunnerServiceImpl.java @@ -95,7 +95,7 @@ public class AppFlowEngineRunnerServiceImpl implements AppFlowEngineRunnerServic return RunnerResult.builder() .result(endRunningContext) - .context(sysContext) +// .context(sysContext) .build(); } diff --git a/metis-starter/src/main/java/com/metis/llm/domain/Usage.java b/metis-starter/src/main/java/com/metis/llm/domain/Usage.java new file mode 100644 index 0000000..8e46dca --- /dev/null +++ b/metis-starter/src/main/java/com/metis/llm/domain/Usage.java @@ -0,0 +1,49 @@ +package com.metis.llm.domain; + +import cn.hutool.core.util.ObjectUtil; +import dev.langchain4j.model.openai.OpenAiTokenUsage; +import dev.langchain4j.model.output.TokenUsage; +import lombok.Data; + +@Data +public class Usage { + private Integer inputTokenCount; + private Integer outputTokenCount; + private Integer totalTokenCount; + private InputTokensDetails inputTokensDetails; + private OutputTokensDetails outputTokensDetails; + + @Data + private static class InputTokensDetails { + private Integer cachedTokens; + } + + @Data + private static class OutputTokensDetails { + private Integer reasoningTokens; + } + + + public static Usage buildTokenUsage(TokenUsage tokenUsage) { + Usage usage = new Usage(); + usage.setInputTokenCount(tokenUsage.inputTokenCount()); + usage.setOutputTokenCount(tokenUsage.outputTokenCount()); + usage.setTotalTokenCount(tokenUsage.totalTokenCount()); + if (tokenUsage instanceof OpenAiTokenUsage openAiTokenUsage) { + if (ObjectUtil.isNotNull(openAiTokenUsage.inputTokensDetails())) { + InputTokensDetails inputTokensDetails = new InputTokensDetails(); + Integer cachedTokens = openAiTokenUsage.inputTokensDetails().cachedTokens(); + inputTokensDetails.setCachedTokens(cachedTokens); + usage.setInputTokensDetails(inputTokensDetails); + } + if (ObjectUtil.isNotNull(openAiTokenUsage.outputTokensDetails())) { + OutputTokensDetails outputTokensDetails = new OutputTokensDetails(); + Integer reasoningTokens = openAiTokenUsage.outputTokensDetails().reasoningTokens(); + outputTokensDetails.setReasoningTokens(reasoningTokens); + usage.setOutputTokensDetails(outputTokensDetails); + } + } + + return usage; + } +} diff --git a/metis-starter/src/main/java/com/metis/llm/engine/OpenApiModelEngine.java b/metis-starter/src/main/java/com/metis/llm/engine/OpenApiModelEngine.java index ad2b494..d3d5247 100644 --- a/metis-starter/src/main/java/com/metis/llm/engine/OpenApiModelEngine.java +++ b/metis-starter/src/main/java/com/metis/llm/engine/OpenApiModelEngine.java @@ -2,7 +2,7 @@ package com.metis.llm.engine; import com.metis.domain.entity.base.Model; import com.metis.enums.ModelTypeEnum; -import com.metis.llm.ModelEngine; +import com.metis.llm.service.ModelEngine; import com.metis.llm.domain.CompletionParams; import com.metis.llm.domain.LLMChatModeConfig; import com.metis.llm.domain.LLMEmbeddingModelConfig; diff --git a/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineFactory.java b/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineFactory.java index e97dc26..15aadfd 100644 --- a/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineFactory.java +++ b/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineFactory.java @@ -2,8 +2,8 @@ package com.metis.llm.factory; import cn.hutool.core.lang.Assert; import com.metis.enums.ModelTypeEnum; -import com.metis.llm.CustomModelEngine; -import com.metis.llm.ModelEngine; +import com.metis.llm.service.CustomModelEngine; +import com.metis.llm.service.ModelEngine; import com.metis.llm.domain.config.BaseModelConfig; import java.util.ArrayList; diff --git a/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineInitiate.java b/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineInitiate.java index 47b44ba..2f664d7 100644 --- a/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineInitiate.java +++ b/metis-starter/src/main/java/com/metis/llm/factory/ModelEngineInitiate.java @@ -3,8 +3,8 @@ package com.metis.llm.factory; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.ObjectUtil; import com.metis.enums.ModelTypeEnum; -import com.metis.llm.CustomModelEngine; -import com.metis.llm.ModelEngine; +import com.metis.llm.service.CustomModelEngine; +import com.metis.llm.service.ModelEngine; import com.metis.llm.domain.config.BaseModelConfig; import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; diff --git a/metis-starter/src/main/java/com/metis/llm/service/ChatBoot.java b/metis-starter/src/main/java/com/metis/llm/service/ChatBoot.java new file mode 100644 index 0000000..c725c60 --- /dev/null +++ b/metis-starter/src/main/java/com/metis/llm/service/ChatBoot.java @@ -0,0 +1,12 @@ +package com.metis.llm.service; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.service.Result; + +import java.util.List; + +public interface ChatBoot { + + Result chat(List chatMessageList); + +} diff --git a/metis-starter/src/main/java/com/metis/llm/CustomModelEngine.java b/metis-starter/src/main/java/com/metis/llm/service/CustomModelEngine.java similarity index 90% rename from metis-starter/src/main/java/com/metis/llm/CustomModelEngine.java rename to metis-starter/src/main/java/com/metis/llm/service/CustomModelEngine.java index 173bcd4..f58a6ef 100644 --- a/metis-starter/src/main/java/com/metis/llm/CustomModelEngine.java +++ b/metis-starter/src/main/java/com/metis/llm/service/CustomModelEngine.java @@ -1,4 +1,4 @@ -package com.metis.llm; +package com.metis.llm.service; import com.metis.enums.ModelTypeEnum; import com.metis.llm.domain.config.BaseModelConfig; diff --git a/metis-starter/src/main/java/com/metis/llm/service/FlowAiServices.java b/metis-starter/src/main/java/com/metis/llm/service/FlowAiServices.java new file mode 100644 index 0000000..d30f6d4 --- /dev/null +++ b/metis-starter/src/main/java/com/metis/llm/service/FlowAiServices.java @@ -0,0 +1,528 @@ +package com.metis.llm.service; + +import dev.langchain4j.Internal; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.ChatRequestParameters; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.input.structured.StructuredPrompt; +import dev.langchain4j.model.input.structured.StructuredPromptProcessor; +import dev.langchain4j.model.moderation.Moderation; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.query.Metadata; +import dev.langchain4j.service.*; +import dev.langchain4j.service.memory.ChatMemoryAccess; +import dev.langchain4j.service.memory.ChatMemoryService; +import dev.langchain4j.service.output.ServiceOutputParser; +import dev.langchain4j.service.tool.ToolServiceContext; +import dev.langchain4j.service.tool.ToolServiceResult; +import dev.langchain4j.spi.services.AiServicesFactory; +import dev.langchain4j.spi.services.TokenStreamAdapter; + +import java.io.InputStream; +import java.lang.reflect.*; +import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import static dev.langchain4j.internal.Exceptions.illegalArgument; +import static dev.langchain4j.internal.Utils.isNotNullOrBlank; +import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA; +import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON; +import static dev.langchain4j.service.IllegalConfigurationException.illegalConfiguration; +import static dev.langchain4j.service.TypeUtils.typeHasRawClass; +import static dev.langchain4j.spi.ServiceHelper.loadFactories; + +@Internal +public class FlowAiServices extends AiServices { + + + private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser(); + private final Collection tokenStreamAdapters = loadFactories(TokenStreamAdapter.class); + + FlowAiServices(AiServiceContext context) { + super(context); + } + + /** + * Begins the construction of an AI Service. + * + * @param aiService The class of the interface to be implemented. + * @return builder + */ + public static AiServices builder(Class aiService) { + AiServiceContext context = new AiServiceContext(aiService); + for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) { + return factory.create(context); + } + return new FlowAiServices<>(context); + } + + + static void validateParameters(Method method) { + Parameter[] parameters = method.getParameters(); + if (parameters == null || parameters.length < 2) { + return; + } + + for (Parameter parameter : parameters) { + V v = parameter.getAnnotation(V.class); + dev.langchain4j.service.UserMessage userMessage = + parameter.getAnnotation(dev.langchain4j.service.UserMessage.class); + MemoryId memoryId = parameter.getAnnotation(MemoryId.class); + UserName userName = parameter.getAnnotation(UserName.class); + if (v == null && userMessage == null && memoryId == null && userName == null) { + throw illegalConfiguration( + "Parameter '%s' of method '%s' should be annotated with @V or @UserMessage " + + "or @UserName or @MemoryId", + parameter.getName(), method.getName()); + } + } + } + + + + public T build() { + + performBasicValidation(); + + if (!context.hasChatMemory() && ChatMemoryAccess.class.isAssignableFrom(context.aiServiceClass)) { + throw illegalConfiguration( + "In order to have a service implementing ChatMemoryAccess, please configure the ChatMemoryProvider on the '%s'.", + context.aiServiceClass.getName()); + } + + for (Method method : context.aiServiceClass.getMethods()) { + if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) { + throw illegalConfiguration( + "The @Moderate annotation is present, but the moderationModel is not set up. " + + "Please ensure a valid moderationModel is configured before using the @Moderate annotation."); + } + + Class returnType = method.getReturnType(); + if (returnType == void.class) { + throw illegalConfiguration("'%s' is not a supported return type of an AI Service method", returnType.getName()); + } + if (returnType == Result.class || returnType == List.class || returnType == Set.class) { + TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType()); + } + + if (!context.hasChatMemory()) { + for (Parameter parameter : method.getParameters()) { + if (parameter.isAnnotationPresent(MemoryId.class)) { + throw illegalConfiguration( + "In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.", + context.aiServiceClass.getName()); + } + } + } + } + + Object proxyInstance = Proxy.newProxyInstance( + context.aiServiceClass.getClassLoader(), + new Class[]{context.aiServiceClass}, + new InvocationHandler() { + + private final ExecutorService executor = Executors.newCachedThreadPool(); + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Exception { + + if (method.getDeclaringClass() == Object.class) { + // methods like equals(), hashCode() and toString() should not be handled by this proxy + return method.invoke(this, args); + } + + if (method.getDeclaringClass() == ChatMemoryAccess.class) { + return switch (method.getName()) { + case "getChatMemory" -> context.chatMemoryService.getChatMemory(args[0]); + case "evictChatMemory" -> context.chatMemoryService.evictChatMemory(args[0]) != null; + default -> throw new UnsupportedOperationException( + "Unknown method on ChatMemoryAccess class : " + method.getName()); + }; + } + + validateParameters(method); + + final Object memoryId = findMemoryId(method, args).orElse(ChatMemoryService.DEFAULT); + final ChatMemory chatMemory = context.hasChatMemory() + ? context.chatMemoryService.getOrCreateChatMemory(memoryId) + : null; + + Optional systemMessage = prepareSystemMessage(memoryId, method, args); + UserMessage userMessage = prepareUserMessage(method, args); + AugmentationResult augmentationResult = null; + if (context.retrievalAugmentor != null) { + List chatMemoryMessages = chatMemory != null ? chatMemory.messages() : null; + Metadata metadata = Metadata.from(userMessage, memoryId, chatMemoryMessages); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + augmentationResult = context.retrievalAugmentor.augment(augmentationRequest); + userMessage = (UserMessage) augmentationResult.chatMessage(); + } + + Type returnType = method.getGenericReturnType(); + boolean streaming = returnType == TokenStream.class || canAdaptTokenStreamTo(returnType); + boolean supportsJsonSchema = supportsJsonSchema(); + Optional jsonSchema = Optional.empty(); + if (supportsJsonSchema && !streaming) { + jsonSchema = serviceOutputParser.jsonSchema(returnType); + } + if ((!supportsJsonSchema || jsonSchema.isEmpty()) && !streaming) { + userMessage = appendOutputFormatInstructions(returnType, userMessage); + } + + List messages; + if (chatMemory != null) { + systemMessage.ifPresent(chatMemory::add); + chatMemory.add(userMessage); + messages = chatMemory.messages(); + } else { + messages = new ArrayList<>(); + systemMessage.ifPresent(messages::add); + messages.add(userMessage); + } + + Future moderationFuture = triggerModerationIfNeeded(method, messages); + + ToolServiceContext toolServiceContext = + context.toolService.createContext(memoryId, userMessage); + + if (streaming) { + TokenStream tokenStream = new AiServiceTokenStream(AiServiceTokenStreamParameters.builder() + .messages(messages) + .toolSpecifications(toolServiceContext.toolSpecifications()) + .toolExecutors(toolServiceContext.toolExecutors()) + .retrievedContents( + augmentationResult != null ? augmentationResult.contents() : null) + .context(context) + .memoryId(memoryId) + .build()); + // TODO moderation + if (returnType == TokenStream.class) { + return tokenStream; + } else { + return adapt(tokenStream, returnType); + } + } + + ResponseFormat responseFormat = null; + if (supportsJsonSchema && jsonSchema.isPresent()) { + responseFormat = ResponseFormat.builder() + .type(JSON) + .jsonSchema(jsonSchema.get()) + .build(); + } + + ChatRequestParameters parameters = ChatRequestParameters.builder() + .toolSpecifications(toolServiceContext.toolSpecifications()) + .responseFormat(responseFormat) + .build(); + + ChatRequest chatRequest = ChatRequest.builder() + .messages(messages) + .parameters(parameters) + .build(); + + ChatResponse chatResponse = context.chatModel.chat(chatRequest); + + verifyModerationIfNeeded(moderationFuture); + + ToolServiceResult toolServiceResult = context.toolService.executeInferenceAndToolsLoop( + chatResponse, + parameters, + messages, + context.chatModel, + chatMemory, + memoryId, + toolServiceContext.toolExecutors()); + + chatResponse = toolServiceResult.chatResponse(); + + Object parsedResponse = serviceOutputParser.parse(chatResponse, returnType); + if (typeHasRawClass(returnType, Result.class)) { + return Result.builder() + .content(parsedResponse) + .tokenUsage(chatResponse.tokenUsage()) + .sources(augmentationResult == null ? null : augmentationResult.contents()) + .finishReason(chatResponse.finishReason()) + .toolExecutions(toolServiceResult.toolExecutions()) + .build(); + } else { + return parsedResponse; + } + } + + private boolean canAdaptTokenStreamTo(Type returnType) { + for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) { + if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) { + return true; + } + } + return false; + } + + private Object adapt(TokenStream tokenStream, Type returnType) { + for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) { + if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) { + return tokenStreamAdapter.adapt(tokenStream); + } + } + throw new IllegalStateException("Can't find suitable TokenStreamAdapter"); + } + + private boolean supportsJsonSchema() { + return context.chatModel != null + && context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA); + } + + private UserMessage appendOutputFormatInstructions(Type returnType, UserMessage userMessage) { + String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType); + String text = userMessage.singleText() + outputFormatInstructions; + if (isNotNullOrBlank(userMessage.name())) { + userMessage = UserMessage.from(userMessage.name(), text); + } else { + userMessage = UserMessage.from(text); + } + return userMessage; + } + + private Future triggerModerationIfNeeded(Method method, List messages) { + if (method.isAnnotationPresent(Moderate.class)) { + return executor.submit(() -> { + List messagesToModerate = removeToolMessages(messages); + return context.moderationModel + .moderate(messagesToModerate) + .content(); + }); + } + return null; + } + }); + + return (T) proxyInstance; + } + + private Optional prepareSystemMessage(Object memoryId, Method method, Object[] args) { + return findSystemMessageTemplate(memoryId, method) + .map(systemMessageTemplate -> PromptTemplate.from(systemMessageTemplate) + .apply(findTemplateVariables(systemMessageTemplate, method, args)) + .toSystemMessage()); + } + + private Optional findSystemMessageTemplate(Object memoryId, Method method) { + dev.langchain4j.service.SystemMessage annotation = + method.getAnnotation(dev.langchain4j.service.SystemMessage.class); + if (annotation != null) { + return Optional.of(getTemplate( + method, "System", annotation.fromResource(), annotation.value(), annotation.delimiter())); + } + + return context.systemMessageProvider.apply(memoryId); + } + + private static Map findTemplateVariables(String template, Method method, Object[] args) { + Parameter[] parameters = method.getParameters(); + + Map variables = new HashMap<>(); + for (int i = 0; i < parameters.length; i++) { + String variableName = getVariableName(parameters[i]); + Object variableValue = args[i]; + variables.put(variableName, variableValue); + } + + if (template.contains("{{it}}") && !variables.containsKey("it")) { + String itValue = getValueOfVariableIt(parameters, args); + variables.put("it", itValue); + } + + return variables; + } + + private static String getVariableName(Parameter parameter) { + V annotation = parameter.getAnnotation(V.class); + if (annotation != null) { + return annotation.value(); + } else { + return parameter.getName(); + } + } + + private static String getValueOfVariableIt(Parameter[] parameters, Object[] args) { + if (parameters.length == 1) { + Parameter parameter = parameters[0]; + if (!parameter.isAnnotationPresent(MemoryId.class) + && !parameter.isAnnotationPresent(dev.langchain4j.service.UserMessage.class) + && !parameter.isAnnotationPresent(UserName.class) + && (!parameter.isAnnotationPresent(V.class) || isAnnotatedWithIt(parameter))) { + return toString(args[0]); + } + } + + for (int i = 0; i < parameters.length; i++) { + if (isAnnotatedWithIt(parameters[i])) { + return toString(args[i]); + } + } + + throw illegalConfiguration("Error: cannot find the value of the prompt template variable \"{{it}}\"."); + } + + private static boolean isAnnotatedWithIt(Parameter parameter) { + V annotation = parameter.getAnnotation(V.class); + return annotation != null && "it".equals(annotation.value()); + } + + private static UserMessage prepareUserMessage(Method method, Object[] args) { + + String template = getUserMessageTemplate(method, args); + Map variables = findTemplateVariables(template, method, args); + + Prompt prompt = PromptTemplate.from(template).apply(variables); + + Optional maybeUserName = findUserName(method.getParameters(), args); + return maybeUserName + .map(userName -> UserMessage.from(userName, prompt.text())) + .orElseGet(prompt::toUserMessage); + } + + private static String getUserMessageTemplate(Method method, Object[] args) { + + Optional templateFromMethodAnnotation = findUserMessageTemplateFromMethodAnnotation(method); + Optional templateFromParameterAnnotation = + findUserMessageTemplateFromAnnotatedParameter(method.getParameters(), args); + + if (templateFromMethodAnnotation.isPresent() && templateFromParameterAnnotation.isPresent()) { + throw illegalConfiguration( + "Error: The method '%s' has multiple @UserMessage annotations. Please use only one.", + method.getName()); + } + + if (templateFromMethodAnnotation.isPresent()) { + return templateFromMethodAnnotation.get(); + } + if (templateFromParameterAnnotation.isPresent()) { + return templateFromParameterAnnotation.get(); + } + + Optional templateFromTheOnlyArgument = + findUserMessageTemplateFromTheOnlyArgument(method.getParameters(), args); + if (templateFromTheOnlyArgument.isPresent()) { + return templateFromTheOnlyArgument.get(); + } + + throw illegalConfiguration("Error: The method '%s' does not have a user message defined.", method.getName()); + } + + private static Optional findUserMessageTemplateFromMethodAnnotation(Method method) { + return Optional.ofNullable(method.getAnnotation(dev.langchain4j.service.UserMessage.class)) + .map(a -> getTemplate(method, "User", a.fromResource(), a.value(), a.delimiter())); + } + + private static Optional findUserMessageTemplateFromAnnotatedParameter( + Parameter[] parameters, Object[] args) { + for (int i = 0; i < parameters.length; i++) { + if (parameters[i].isAnnotationPresent(dev.langchain4j.service.UserMessage.class)) { + return Optional.of(toString(args[i])); + } + } + return Optional.empty(); + } + + private static Optional findUserMessageTemplateFromTheOnlyArgument(Parameter[] parameters, Object[] args) { + if (parameters != null && parameters.length == 1 && parameters[0].getAnnotations().length == 0) { + return Optional.of(toString(args[0])); + } + return Optional.empty(); + } + + private static Optional findUserName(Parameter[] parameters, Object[] args) { + for (int i = 0; i < parameters.length; i++) { + if (parameters[i].isAnnotationPresent(UserName.class)) { + return Optional.of(args[i].toString()); + } + } + return Optional.empty(); + } + + private static String getTemplate(Method method, String type, String resource, String[] value, String delimiter) { + String messageTemplate; + if (!resource.trim().isEmpty()) { + messageTemplate = getResourceText(method.getDeclaringClass(), resource); + if (messageTemplate == null) { + throw illegalConfiguration("@%sMessage's resource '%s' not found", type, resource); + } + } else { + messageTemplate = String.join(delimiter, value); + } + if (messageTemplate.trim().isEmpty()) { + throw illegalConfiguration("@%sMessage's template cannot be empty", type); + } + return messageTemplate; + } + + private static String getResourceText(Class clazz, String resource) { + InputStream inputStream = clazz.getResourceAsStream(resource); + if (inputStream == null) { + inputStream = clazz.getResourceAsStream("/" + resource); + } + return getText(inputStream); + } + + private static String getText(InputStream inputStream) { + if (inputStream == null) { + return null; + } + try (Scanner scanner = new Scanner(inputStream); + Scanner s = scanner.useDelimiter("\\A")) { + return s.hasNext() ? s.next() : ""; + } + } + + private static Optional findMemoryId(Method method, Object[] args) { + Parameter[] parameters = method.getParameters(); + for (int i = 0; i < parameters.length; i++) { + if (parameters[i].isAnnotationPresent(MemoryId.class)) { + Object memoryId = args[i]; + if (memoryId == null) { + throw illegalArgument( + "The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null", + parameters[i].getName(), method.getName()); + } + return Optional.of(memoryId); + } + } + return Optional.empty(); + } + + private static String toString(Object arg) { + if (arg.getClass().isArray()) { + return arrayToString(arg); + } else if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) { + return StructuredPromptProcessor.toPrompt(arg).text(); + } else { + return arg.toString(); + } + } + + private static String arrayToString(Object arg) { + StringBuilder sb = new StringBuilder("["); + int length = Array.getLength(arg); + for (int i = 0; i < length; i++) { + sb.append(toString(Array.get(arg, i))); + if (i < length - 1) { + sb.append(", "); + } + } + sb.append("]"); + return sb.toString(); + } +} diff --git a/metis-starter/src/main/java/com/metis/llm/ModelEngine.java b/metis-starter/src/main/java/com/metis/llm/service/ModelEngine.java similarity index 96% rename from metis-starter/src/main/java/com/metis/llm/ModelEngine.java rename to metis-starter/src/main/java/com/metis/llm/service/ModelEngine.java index 6fafd26..d6cc6c6 100644 --- a/metis-starter/src/main/java/com/metis/llm/ModelEngine.java +++ b/metis-starter/src/main/java/com/metis/llm/service/ModelEngine.java @@ -1,4 +1,4 @@ -package com.metis.llm; +package com.metis.llm.service; import com.metis.domain.entity.base.Model; diff --git a/metis-starter/src/main/java/com/metis/llm/ModelEngineService.java b/metis-starter/src/main/java/com/metis/llm/service/ModelEngineService.java similarity index 94% rename from metis-starter/src/main/java/com/metis/llm/ModelEngineService.java rename to metis-starter/src/main/java/com/metis/llm/service/ModelEngineService.java index 2463934..098bf7b 100644 --- a/metis-starter/src/main/java/com/metis/llm/ModelEngineService.java +++ b/metis-starter/src/main/java/com/metis/llm/service/ModelEngineService.java @@ -1,4 +1,4 @@ -package com.metis.llm; +package com.metis.llm.service; import com.metis.llm.domain.LLMChatModeConfig; import dev.langchain4j.model.chat.ChatModel; diff --git a/metis-starter/src/main/java/com/metis/llm/ModelService.java b/metis-starter/src/main/java/com/metis/llm/service/ModelService.java similarity index 89% rename from metis-starter/src/main/java/com/metis/llm/ModelService.java rename to metis-starter/src/main/java/com/metis/llm/service/ModelService.java index 0afb5c6..486de7b 100644 --- a/metis-starter/src/main/java/com/metis/llm/ModelService.java +++ b/metis-starter/src/main/java/com/metis/llm/service/ModelService.java @@ -1,4 +1,4 @@ -package com.metis.llm; +package com.metis.llm.service; import com.metis.domain.entity.base.Model; diff --git a/metis-starter/src/main/java/com/metis/llm/impl/ModelEngineServiceImpl.java b/metis-starter/src/main/java/com/metis/llm/service/impl/ModelEngineServiceImpl.java similarity index 93% rename from metis-starter/src/main/java/com/metis/llm/impl/ModelEngineServiceImpl.java rename to metis-starter/src/main/java/com/metis/llm/service/impl/ModelEngineServiceImpl.java index 7bf0815..5d423d3 100644 --- a/metis-starter/src/main/java/com/metis/llm/impl/ModelEngineServiceImpl.java +++ b/metis-starter/src/main/java/com/metis/llm/service/impl/ModelEngineServiceImpl.java @@ -1,12 +1,12 @@ -package com.metis.llm.impl; +package com.metis.llm.service.impl; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.ObjectUtil; import com.metis.domain.entity.base.Model; import com.metis.enums.ModelTypeEnum; -import com.metis.llm.ModelEngine; -import com.metis.llm.ModelEngineService; -import com.metis.llm.ModelService; +import com.metis.llm.service.ModelEngine; +import com.metis.llm.service.ModelEngineService; +import com.metis.llm.service.ModelService; import com.metis.llm.domain.LLMChatModeConfig; import com.metis.llm.domain.LLMEmbeddingModelConfig; import com.metis.llm.factory.ModelEngineFactory; diff --git a/metis-starter/src/main/java/com/metis/llm/impl/ModelServiceImpl.java b/metis-starter/src/main/java/com/metis/llm/service/impl/ModelServiceImpl.java similarity index 92% rename from metis-starter/src/main/java/com/metis/llm/impl/ModelServiceImpl.java rename to metis-starter/src/main/java/com/metis/llm/service/impl/ModelServiceImpl.java index 56f0591..9bdd1a0 100644 --- a/metis-starter/src/main/java/com/metis/llm/impl/ModelServiceImpl.java +++ b/metis-starter/src/main/java/com/metis/llm/service/impl/ModelServiceImpl.java @@ -1,11 +1,11 @@ -package com.metis.llm.impl; +package com.metis.llm.service.impl; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.ObjectUtil; import com.metis.convert.ModelPlatformConvert; import com.metis.domain.entity.ModelPlatform; import com.metis.domain.entity.base.Model; -import com.metis.llm.ModelService; +import com.metis.llm.service.ModelService; import com.metis.service.ModelPlatformService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; diff --git a/metis-starter/src/main/java/com/metis/runner/RunnerResult.java b/metis-starter/src/main/java/com/metis/runner/RunnerResult.java index 3306ea5..dc7160d 100644 --- a/metis-starter/src/main/java/com/metis/runner/RunnerResult.java +++ b/metis-starter/src/main/java/com/metis/runner/RunnerResult.java @@ -2,12 +2,9 @@ package com.metis.runner; import com.alibaba.fastjson2.JSONObject; -import com.metis.domain.context.SysContext; import lombok.Builder; import lombok.Data; -import java.util.Map; - /** * 运行结果 * @@ -26,7 +23,7 @@ public class RunnerResult { /** * 上下文 */ - private SysContext context; +// private SysContext context; } diff --git a/metis-starter/src/main/java/com/metis/runner/impl/EndNodeRunner.java b/metis-starter/src/main/java/com/metis/runner/impl/EndNodeRunner.java index 4aacda8..9f080af 100644 --- a/metis-starter/src/main/java/com/metis/runner/impl/EndNodeRunner.java +++ b/metis-starter/src/main/java/com/metis/runner/impl/EndNodeRunner.java @@ -1,6 +1,7 @@ package com.metis.runner.impl; +import cn.hutool.core.collection.CollUtil; import com.alibaba.fastjson2.JSONObject; import com.metis.domain.context.RunningContext; import com.metis.domain.context.RunningResult; @@ -21,7 +22,15 @@ public class EndNodeRunner implements NodeRunner { @Override public RunningResult run(RunningContext context, Node node, List edges) { JSONObject contextNodeValue = new JSONObject(); - contextNodeValue.put("userId", context.getSys().getAppId()); + EndNodeConfig config = node.getConfig(); + List variables = config.getVariables(); + if (CollUtil.isEmpty(variables)){ + return RunningResult.buildResult(); + } + for (EndNodeConfig.Variable variable : variables) { + Object value = context.getValue(variable.getVariableKey()); + contextNodeValue.put(variable.getVariable(), value); + } return RunningResult.buildResult(contextNodeValue); } diff --git a/metis-starter/src/main/java/com/metis/runner/impl/StartNodeRunner.java b/metis-starter/src/main/java/com/metis/runner/impl/StartNodeRunner.java index a387f62..5c98a38 100644 --- a/metis-starter/src/main/java/com/metis/runner/impl/StartNodeRunner.java +++ b/metis-starter/src/main/java/com/metis/runner/impl/StartNodeRunner.java @@ -39,8 +39,8 @@ public class StartNodeRunner implements NodeRunner { JSONObject custom = context.getCustom(); JSONObject contextNodeValue = new JSONObject(); // 获取到系统上下文, 并将系统上下文放入到start的运行结果中, 用于后续调用 - JSONObject sysContext = getSysContext(context); - contextNodeValue.putAll(sysContext); +// JSONObject sysContext = getSysContext(context); +// contextNodeValue.putAll(sysContext); for (NodeVariable variable : variables) { Object value = variable.getValue(custom); diff --git a/metis-starter/src/main/java/com/metis/runner/impl/LLMNodeRunner.java b/metis-starter/src/main/java/com/metis/runner/impl/llm/LLMNodeRunner.java similarity index 79% rename from metis-starter/src/main/java/com/metis/runner/impl/LLMNodeRunner.java rename to metis-starter/src/main/java/com/metis/runner/impl/llm/LLMNodeRunner.java index f91bd65..a9ca738 100644 --- a/metis-starter/src/main/java/com/metis/runner/impl/LLMNodeRunner.java +++ b/metis-starter/src/main/java/com/metis/runner/impl/llm/LLMNodeRunner.java @@ -1,4 +1,4 @@ -package com.metis.runner.impl; +package com.metis.runner.impl.llm; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjectUtil; @@ -11,8 +11,11 @@ import com.metis.domain.entity.base.Node; import com.metis.domain.entity.config.node.llm.LLMNodeConfig; import com.metis.domain.entity.config.node.llm.PromptTemplate; import com.metis.enums.NodeType; -import com.metis.llm.ModelEngineService; import com.metis.llm.domain.LLMChatModeConfig; +import com.metis.llm.domain.Usage; +import com.metis.llm.service.ChatBoot; +import com.metis.llm.service.FlowAiServices; +import com.metis.llm.service.ModelEngineService; import com.metis.runner.NodeRunner; import com.metis.template.domain.RenderContext; import com.metis.template.utils.VelocityUtil; @@ -21,7 +24,7 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.ChatModel; -import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.Result; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -45,12 +48,20 @@ public class LLMNodeRunner implements NodeRunner { ChatModel chatModel = modelEngineService.getChatLanguageModel(model); List chatMessageList = buildChatMessage(context, config); - ChatResponse chat = chatModel.chat(chatMessageList); - String string = chat.toString(); - log.info("LLM 输出结果: {}", string); + // todo 需要使用FlowAiServices单独扩展build方法中返回数据格式的操作, 可以用在二期or三期进行扩展开发 + ChatBoot chatBoot = FlowAiServices.builder(ChatBoot.class) + .chatModel(chatModel) + .build(); + Result chatResult = chatBoot.chat(chatMessageList); + String text = chatResult.content(); + + + log.info("LLM 输出结果: {}", text); JSONObject result = new JSONObject(); - result.put(BaseConstant.TEXT, string); + result.put(BaseConstant.TEXT, text); + result.put(BaseConstant.USAGE, Usage.buildTokenUsage(chatResult.tokenUsage())); + result.put(BaseConstant.FINISH_REASON, chatResult.finishReason()); return RunningResult.buildResult(result); } diff --git a/metis-starter/src/test/resources/flow.json b/metis-starter/src/test/resources/flow.json index 526b43a..0d4c520 100644 --- a/metis-starter/src/test/resources/flow.json +++ b/metis-starter/src/test/resources/flow.json @@ -1,4 +1,3 @@ - { "appId": 1919041086810968064, "name": "llm运行测试", @@ -6,7 +5,7 @@ "graph": { "nodes": [ { - "id": "5", + "id": "node_5", "type": "start", "initialized": false, "position": { @@ -49,7 +48,7 @@ "height": 40 }, { - "id": "700", + "id": "node_700", "type": "llm", "initialized": false, "position": { @@ -61,7 +60,7 @@ "icon": "", "toolbarPosition": "right", "config": { - "context": "5.background", + "context": "node_5.background", "retryConfig": { "enable": true, "maxRetries": 3, @@ -75,7 +74,7 @@ }, { "role": "user", - "text": "请你解释一下上述问题${5.query}", + "text": "请你解释一下上述问题${node_5.query}", "id": "2" } ], @@ -112,7 +111,7 @@ "height": 40 }, { - "id": "802", + "id": "node_802", "type": "end", "initialized": false, "position": { @@ -123,7 +122,30 @@ "label": "结束", "icon": "", "toolbarPosition": "right", - "config": {}, + "config": { + "variables": [ + { + "variable": "query", + "variableKey": "node_5.query" + }, + { + "variable": "background", + "variableKey": "node_5.background" + }, + { + "variable": "usage", + "variableKey": "node_700.usage" + }, + { + "variable": "finishReason", + "variableKey": "node_700.finishReason" + }, + { + "variable": "text", + "variableKey": "node_700.text" + } + ] + }, "handles": [ { "id": "13", @@ -142,8 +164,8 @@ { "id": "vueflow__edge-551-70057", "type": "default", - "source": "5", - "target": "700", + "source": "node_5", + "target": "node_700", "sourceHandle": "51", "targetHandle": "57", "data": {}, @@ -159,8 +181,8 @@ { "id": "802", "type": "default", - "source": "700", - "target": "802", + "source": "node_700", + "target": "node_802", "sourceHandle": "35", "targetHandle": "13", "data": {}, @@ -185,5 +207,4 @@ "zoom": 3.5988075528449266 } } - } \ No newline at end of file