feat: 整体项目架构完成, 运行的核心算法完成, 自定义节点starter以外定义节点测试通过

This commit is contained in:
2025-04-22 22:28:28 +08:00
parent 65e6f9f650
commit ff992b8903
29 changed files with 969 additions and 191 deletions

View File

@@ -44,6 +44,11 @@ public class RunningContext {
this.nodeRunningContext.put(nodeId, nodeRunningContext);
}
public JSONObject getRunningContext(Long nodeId) {
return this.nodeRunningContext.get(nodeId);
}
/**
* 构建上下文
*
@@ -58,4 +63,6 @@ public class RunningContext {
.build();
}
}

View File

@@ -35,6 +35,18 @@ public class RunningResult {
.build();
}
/**
* 构建结果
*
* @param nextRunNodeId 下一个运行节点id
* @return {@link RunningResult }
*/
public static RunningResult buildResult(Set<Long> nextRunNodeId) {
return RunningResult.builder()
.nextRunNodeId(nextRunNodeId)
.build();
}
/**
* 构建结果
*

View File

@@ -1,52 +0,0 @@
package com.metis.domain.entity;
import com.metis.domain.entity.base.Edge;
import com.metis.domain.entity.base.Node;
import java.util.*;
public class GraphDemo {
private Map<Long, Node> nodes = new HashMap<>();
private Map<Long, List<Long>> adjacencyList = new HashMap<>();
public void addNode(Node node) {
nodes.put(node.getId(), node);
adjacencyList.put(node.getId(), new ArrayList<>());
}
public void addEdge(Edge edge) {
adjacencyList.get(edge.getSource())
.add(edge.getTarget());
}
public List<Node> topologicalSort() {
List<Node> sortedNodes = new ArrayList<>();
Set<Long> visited = new HashSet<>();
Set<Long> visiting = new HashSet<>();
for (Long nodeId : nodes.keySet()) {
if (!visited.contains(nodeId)) {
dfs(nodeId, visited, visiting, sortedNodes);
}
}
Collections.reverse(sortedNodes);
return sortedNodes;
}
private void dfs(Long nodeId, Set<Long> visited, Set<Long> visiting, List<Node> sortedNodes) {
if (visiting.contains(nodeId)) {
throw new IllegalStateException("Cycle detected in the graph");
}
if (!visited.contains(nodeId)) {
visiting.add(nodeId);
for (Long neighbor : adjacencyList.get(nodeId)) {
dfs(neighbor, visited, visiting, sortedNodes);
}
visiting.remove(nodeId);
visited.add(nodeId);
sortedNodes.add(nodes.get(nodeId));
}
}
}

View File

@@ -0,0 +1,131 @@
package com.metis.domain.entity;
import com.metis.domain.entity.base.Edge;
import com.metis.domain.entity.base.Graph;
import com.metis.domain.entity.base.Node;
import com.metis.enums.NodeType;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
public class GraphDto {
private final Map<Long, Node> nodeMap;
private final Map<Long, Boolean> nodeReadyMap;
private final Map<Long, List<Edge>> edgeMap;
private final Map<Long, List<Long>> adjacencyList = new HashMap<>();
private final List<Node> sortedNodes = new ArrayList<>();
public List<Node> getSortedNodes() {
return new ArrayList<>(sortedNodes);
}
public List<Edge> getEdgeNodeId(Long nodeId) {
return edgeMap.getOrDefault(nodeId, new ArrayList<>());
}
public Node getEndNode(){
return sortedNodes.stream()
.filter(node -> NodeType.END.equals(node.getType()))
.findFirst()
.orElse(null);
}
public Node getNode(Long nodeId) {
return nodeMap.get(nodeId);
}
public void updateNodeReadyMap(Long nodeId, Boolean ready) {
nodeReadyMap.put(nodeId, ready);
}
public Boolean isNodeReady(Long nodeId) {
return nodeReadyMap.get(nodeId);
}
private GraphDto(List<Node> nodes, List<Edge> edges) {
this.edgeMap = edges.stream()
.collect(Collectors.groupingBy(Edge::getSource));
this.nodeMap = nodes.stream()
.collect(Collectors.toMap(Node::getId, Function.identity()));
this.nodeReadyMap = nodes.stream()
.collect(Collectors.toMap(Node::getId, node -> false));
initAdjacencyList(edges);
List<Node> nodeList = topologicalSort();
this.sortedNodes.addAll(nodeList);
Node node = sortedNodes.get(0);
if (NodeType.START.equals(node.getType())) {
nodeReadyMap.put(node.getId(), true);
}
}
private void initAdjacencyList(List<Edge> edges) {
for (Edge edge : edges) {
List<Long> targetList = adjacencyList.getOrDefault(edge.getSource(), new ArrayList<>());
targetList.add(edge.getTarget());
adjacencyList.put(edge.getSource(), targetList);
}
}
/**
* 拓扑排序
*
* @return {@link List }<{@link Node }>
*/
private List<Node> topologicalSort() {
List<Node> sortedNodes = new ArrayList<>();
Set<Long> visited = new HashSet<>();
Set<Long> visiting = new HashSet<>();
for (Long nodeId : nodeMap.keySet()) {
if (!visited.contains(nodeId)) {
dfs(nodeId, visited, visiting, sortedNodes);
}
}
Collections.reverse(sortedNodes);
return sortedNodes;
}
/**
* 深度遍历找到运行顺序
*
* @param nodeId 节点id
* @param visited 参观了
* @param visiting 参观
* @param sortedNodes 排序节点
*/
private void dfs(Long nodeId, Set<Long> visited, Set<Long> visiting, List<Node> sortedNodes) {
if (visiting.contains(nodeId)) {
throw new IllegalStateException("Cycle detected in the graph");
}
if (!visited.contains(nodeId)) {
visiting.add(nodeId);
for (Long neighbor : adjacencyList.getOrDefault(nodeId, new ArrayList<>())) {
dfs(neighbor, visited, visiting, sortedNodes);
}
visiting.remove(nodeId);
visited.add(nodeId);
sortedNodes.add(nodeMap.get(nodeId));
}
}
/**
* 构建对象
*
* @param graph 图
* @return {@link GraphDto }
*/
public static GraphDto of(Graph graph) {
return new GraphDto(graph.getNodes(), graph.getEdges());
}
}

View File

@@ -0,0 +1,10 @@
package com.metis.domain.entity.config.node;
import com.metis.domain.entity.base.NodeConfig;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
public class LLMNodeConfig extends NodeConfig {
}

View File

@@ -0,0 +1,10 @@
package com.metis.domain.entity.config.node;
import com.metis.domain.entity.base.NodeConfig;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
public class QuestionClassifierConfig extends NodeConfig {
}

View File

@@ -68,6 +68,7 @@ public class AppEngineServiceImpl implements AppEngineService {
}
@Override
@Transactional(rollbackFor = Exception.class)
public App create(CreateApp createApp) {
BuildApp buildApp = BaseAppConvert.INSTANCE.toBuildApp(createApp);
// 校验

View File

@@ -5,12 +5,14 @@ import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import com.metis.domain.context.RunningContext;
import com.metis.domain.context.RunningResult;
import com.metis.domain.context.SysContext;
import com.metis.domain.entity.App;
import com.metis.domain.entity.GraphDto;
import com.metis.domain.entity.base.Edge;
import com.metis.domain.entity.base.Graph;
import com.metis.domain.entity.base.Node;
import com.metis.engine.AppEngineService;
import com.metis.engine.AppFlowEngineRunnerService;
@@ -25,8 +27,6 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
@Slf4j
@Service
@@ -40,7 +40,7 @@ public class AppFlowEngineRunnerServiceImpl implements AppFlowEngineRunnerServic
public RunnerResult running(FlowRunningContext context) {
App app = getApp(context);
Assert.isTrue(ObjectUtil.isNotNull(app), "app为空");
// todo 构建运行实例, 并将运行实例放入上下文
// 构建运行实例, 并将运行实例放入上下文
Long instanceId = IdUtil.getSnowflakeNextId();
// 构建系统上下文信息
SysContext sysContext = SysContext.builder()
@@ -52,26 +52,49 @@ public class AppFlowEngineRunnerServiceImpl implements AppFlowEngineRunnerServic
// 构建运行中上下文
RunningContext runningContext = RunningContext.buildContext(sysContext, context);
// 构建节点映射对象
Graph graph = app.getGraph();
Map<Long, Node> nodeMap = graph.getNodes().stream()
.collect(Collectors.toMap(Node::getId, Function.identity()));
Map<Long, List<Edge>> edgeMap = graph.getEdges().stream()
.collect(Collectors.groupingBy(Edge::getSource));
GraphDto graph = GraphDto.of(app.getGraph());
Set<Node> readyRunningNode = new HashSet<>();
// 获取到开始节点
// 开始节点为空,则表示数据存在异常
Assert.isTrue(ObjectUtil.isNotNull(readyRunningNode), "流程图不存在开始节点");
while (CollUtil.isNotEmpty(readyRunningNode)) {
// todo 出现多个节点同时运行, 需要找到他们最终运行的聚合节点, 前期默认只有一条线路运行, 不支持并行流程
doRunning(readyRunningNode, edgeMap, runningContext);
readyRunningNode = null;
for (Node node : graph.getSortedNodes()) {
Long nodeId = node.getId();
if (!graph.isNodeReady(nodeId)) {
continue;
}
log.info("当前运行节点 id:{}, name:{}, type:{}", node.getId(), node.getData().getLabel(), node.getType());
// 当前节点接下来的连接线信息
List<Edge> edges = graph.getEdgeNodeId(nodeId);
// 执行
NodeRunner nodeRunner = getNodeRunner(node);
node.setConfigClass(GenericInterfacesUtils.getClass(nodeRunner));
// 下一个需要运行的节点id加入到可以运行的节点中
// 获取到返回结果
RunningResult result = nodeRunner.run(runningContext, node, edges);
log.info("节点执行结果:{}", JSON.toJSONString(result));
// 节点执行结果参数放入上下文中
if (ObjectUtil.isNotNull(result.getNodeContext())) {
runningContext.addNodeRunningContext(node.getId(), result.getNodeContext());
}
// 下一个需要运行的节点id加入到可以运行的节点中
if (CollUtil.isNotEmpty(result.getNextRunNodeId())) {
for (Long nextNodeId : result.getNextRunNodeId()) {
graph.updateNodeReadyMap(nextNodeId, true);
}
} else {
// 如果没有返回, 则认为所有的下级节点都需要运行
edges.forEach(edge -> {
graph.updateNodeReadyMap(edge.getTarget(), true);
});
}
}
Node endNode = graph.getEndNode();
JSONObject endRunningContext = runningContext.getRunningContext(endNode.getId());
return RunnerResult.builder()
.content("你他妈的!")
.result(endRunningContext)
.context(sysContext)
.build();
@@ -113,7 +136,7 @@ public class AppFlowEngineRunnerServiceImpl implements AppFlowEngineRunnerServic
* @return {@link NodeRunner }
*/
private NodeRunner getNodeRunner(Node node) {
if (NodeType.CUSTOM_NODE.equals(node.getType())) {
if (NodeType.CUSTOM.equals(node.getType())) {
Assert.isTrue(StrUtil.isNotBlank(node.getCustomType()), "自定义节点类型不能为空");
return NodeRunnerFactory.getCustom(node.getCustomType());
}

View File

@@ -13,8 +13,14 @@ public enum NodeType {
START(1, "start", "开始"),
END(2, "end", "结束"),
DOCUMENT_EXTRACTOR(3, "document-extractor", "文档提取器"),
CUSTOM_NODE(4, "Custom-Node", "自定义节点");
DOCUMENT_EXTRACTOR(3, "documentExtractor", "文档提取器"),
CUSTOM(4, "custom", "自定义节点"),
LLM(5, "llm", "LLM"),
QUESTION_CLASSIFIER(6, "questionClassifier", "问题分类器"),
IF_ELSE(7, "ifElse", "条件判断"),
;
private final Integer code;
@@ -24,8 +30,6 @@ public enum NodeType {
private final String name;
// private final Class<?> configClass;
/**
* 枚举序列化器(前端传code时自动转换为对应枚举)

View File

@@ -1,8 +1,8 @@
package com.metis.facade;
import com.metis.domain.bo.ProcessBo;
import com.metis.convert.GraphConvert;
import com.metis.domain.bo.CreateApp;
import com.metis.domain.bo.ProcessBo;
import com.metis.domain.bo.UpdateApp;
import com.metis.domain.entity.App;
import com.metis.domain.entity.base.Graph;

View File

@@ -26,7 +26,7 @@ public interface CustomNodeRunner<T extends NodeConfig> extends NodeRunner<T> {
* @return {@link NodeType }
*/
default NodeType getType() {
return NodeType.CUSTOM_NODE;
return NodeType.CUSTOM;
}

View File

@@ -1,5 +1,6 @@
package com.metis.runner;
import cn.hutool.core.collection.CollUtil;
import com.metis.domain.context.RunningContext;
import com.metis.domain.context.RunningResult;
import com.metis.domain.entity.base.Edge;
@@ -8,6 +9,8 @@ import com.metis.domain.entity.base.NodeConfig;
import com.metis.enums.NodeType;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* 内置节点运行器
@@ -36,4 +39,18 @@ public interface NodeRunner<T extends NodeConfig> {
*/
NodeType getType();
/**
* 获取下一个节点id
*
* @param edges 边缘
* @return {@link Set }<{@link Long }>
*/
default Set<Long> getNextNodeIds(List<Edge> edges) {
if (CollUtil.isEmpty(edges)) {
return Set.of();
}
return edges.stream().map(Edge::getTarget).collect(Collectors.toSet());
}
}

View File

@@ -1,10 +1,13 @@
package com.metis.runner;
import com.alibaba.fastjson2.JSONObject;
import com.metis.domain.context.SysContext;
import lombok.Builder;
import lombok.Data;
import java.util.Map;
/**
* 运行结果
*
@@ -18,7 +21,7 @@ public class RunnerResult {
/**
* 运行内容
*/
private String content;
private JSONObject result;
/**
* 上下文

View File

@@ -26,7 +26,7 @@ public class RunnerInitialize implements ApplicationContextAware {
Map<String, NodeRunner> runnerMap = applicationContext.getBeansOfType(NodeRunner.class);
runnerMap.forEach((runnerBeanName, runner) -> {
if (NodeType.CUSTOM_NODE.equals(runner.getType())) {
if (NodeType.CUSTOM.equals(runner.getType())) {
Assert.isTrue(runner instanceof CustomNodeRunner, "自定义节点必须实现CustomNodeRunner接口");
NodeRunnerFactory.registerCustom((CustomNodeRunner) runner);
} else {

View File

@@ -1,6 +1,7 @@
package com.metis.runner.impl;
import com.alibaba.fastjson2.JSONObject;
import com.metis.domain.context.RunningContext;
import com.metis.domain.context.RunningResult;
import com.metis.domain.entity.base.Edge;
@@ -19,7 +20,9 @@ public class EndNodeRunner implements NodeRunner<EndNodeConfig> {
@Override
public RunningResult run(RunningContext context, Node node, List<Edge> edges) {
return RunningResult.buildResult();
JSONObject contextNodeValue = new JSONObject();
contextNodeValue.put("userId", context.getSys().getAppId());
return RunningResult.buildResult(contextNodeValue);
}
@Override

View File

@@ -0,0 +1,27 @@
package com.metis.runner.impl;
import com.metis.domain.context.RunningContext;
import com.metis.domain.context.RunningResult;
import com.metis.domain.entity.base.Edge;
import com.metis.domain.entity.base.Node;
import com.metis.domain.entity.config.node.LLMNodeConfig;
import com.metis.enums.NodeType;
import com.metis.runner.NodeRunner;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Slf4j
@Service
public class LLMNodeRunner implements NodeRunner<LLMNodeConfig> {
@Override
public RunningResult run(RunningContext context, Node node, List<Edge> edges) {
return RunningResult.buildResult();
}
@Override
public NodeType getType() {
return NodeType.LLM;
}
}

View File

@@ -0,0 +1,36 @@
package com.metis.runner.impl;
import com.metis.domain.context.RunningContext;
import com.metis.domain.context.RunningResult;
import com.metis.domain.entity.base.Edge;
import com.metis.domain.entity.base.Node;
import com.metis.domain.entity.config.node.QuestionClassifierConfig;
import com.metis.enums.NodeType;
import com.metis.runner.NodeRunner;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;
@Slf4j
@Service
public class QuestionClassifierRunner implements NodeRunner<QuestionClassifierConfig> {
@Override
public RunningResult run(RunningContext context, Node node, List<Edge> edges) {
Set<Long> nextNodeIds = getNextNodeIds(edges);
// 生成随机索引
Random random = new Random();
int randomIndex = random.nextInt(nextNodeIds.size());
List<Long> nodeIds = new ArrayList<>(nextNodeIds);
return RunningResult.buildResult(Set.of(nodeIds.get(randomIndex)));
}
@Override
public NodeType getType() {
return NodeType.QUESTION_CLASSIFIER;
}
}

View File

@@ -27,7 +27,6 @@ public class StartNodeRunner implements NodeRunner<StartNodeConfig> {
@Override
public RunningResult run(RunningContext context, Node node, List<Edge> edges) {
log.info("开始节点{}, 节点id: {} 运行", node.getData().getLabel(), node.getId());
StartNodeConfig config = node.getConfig();
// 获取到节点的自定义参数
List<NodeVariable> variables = config.getVariables();
@@ -46,7 +45,6 @@ public class StartNodeRunner implements NodeRunner<StartNodeConfig> {
}
@Override
public NodeType getType() {
return NodeType.START;

View File

@@ -26,7 +26,7 @@ public interface CustomNodeValidator<T extends NodeConfig> extends NodeValidator
* @return {@link NodeType }
*/
default NodeType getType() {
return NodeType.CUSTOM_NODE;
return NodeType.CUSTOM;
}

View File

@@ -30,7 +30,7 @@ public class ValidatorInitialize implements ApplicationContextAware {
Map<String, NodeValidator> nodeMap = applicationContext.getBeansOfType(NodeValidator.class);
nodeMap.forEach((nodeValidatorBeanName, nodeValidator) -> {
if (NodeType.CUSTOM_NODE.equals(nodeValidator.getType())) {
if (NodeType.CUSTOM.equals(nodeValidator.getType())) {
Assert.isTrue(nodeValidator instanceof CustomNodeValidator, "自定义节点必须实现CustomNodeValidator接口");
NodeValidatorFactory.registerCustom((CustomNodeValidator) nodeValidator);
} else {

View File

@@ -258,7 +258,7 @@ public class ValidatorServiceImpl implements ValidatorService {
* @return {@link NodeValidator }
*/
private NodeValidator getNodeValidator(Node node) {
if (NodeType.CUSTOM_NODE.equals(node.getType())) {
if (NodeType.CUSTOM.equals(node.getType())) {
Assert.isTrue(StrUtil.isNotBlank(node.getCustomType()), "自定义节点类型不能为空");
return NodeValidatorFactory.getCustom(node.getCustomType());
}

View File

@@ -31,8 +31,8 @@ public class EndNodeValidator implements NodeValidator<EndNodeConfig> {
Assert.isTrue(targets.isEmpty(), "结束节点 {} 不允许有目标连接", node.getId());
// 2. 检查 sources 数量是否小于 handles 数量
int handleCount = node.getData().getHandles().size();
Assert.isTrue(sources.size() <= handleCount, "结束节点 {} 的源连接数超过 handles 数量", node.getId());
// int handleCount = node.getData().getHandles().size();
// Assert.isTrue(sources.size() <= handleCount, "结束节点 {} 的源连接数超过 handles 数量", node.getId());
return ValidatorResult.valid();
}

View File

@@ -0,0 +1,33 @@
package com.metis.validator.impl.node;
import com.metis.domain.entity.base.Edge;
import com.metis.domain.entity.base.Node;
import com.metis.domain.entity.config.node.LLMNodeConfig;
import com.metis.enums.NodeType;
import com.metis.validator.NodeValidator;
import com.metis.validator.ValidatorResult;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Slf4j
@Service
@RequiredArgsConstructor
public class LLMNodeValidator implements NodeValidator<LLMNodeConfig> {
@Override
public ValidatorResult validateValue(Node node) {
return ValidatorResult.valid();
}
@Override
public ValidatorResult validateRelation(Node node, List<Edge> sources, List<Edge> targets) {
return ValidatorResult.valid();
}
@Override
public NodeType getType() {
return NodeType.LLM;
}
}

View File

@@ -0,0 +1,31 @@
package com.metis.validator.impl.node;
import com.metis.domain.entity.base.Edge;
import com.metis.domain.entity.base.Node;
import com.metis.domain.entity.config.node.QuestionClassifierConfig;
import com.metis.enums.NodeType;
import com.metis.validator.NodeValidator;
import com.metis.validator.ValidatorResult;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Slf4j
@Service
public class QuestionClassifierValidator implements NodeValidator<QuestionClassifierConfig> {
@Override
public ValidatorResult validateValue(Node node) {
return ValidatorResult.valid();
}
@Override
public ValidatorResult validateRelation(Node node, List<Edge> sources, List<Edge> targets) {
return ValidatorResult.valid();
}
@Override
public NodeType getType() {
return NodeType.QUESTION_CLASSIFIER;
}
}

View File

@@ -84,8 +84,8 @@ public class StartNodeValidator implements NodeValidator<StartNodeConfig> {
Assert.isTrue(sources.isEmpty(), "开始节点 {} 不允许有源连接", node.getId());
// 2. 检查 targets 数量是否小于 handles 数量
int handleCount = node.getData().getHandles().size();
Assert.isTrue(targets.size() <= handleCount, "开始节点 {} 的目标连接数超过 handles 数量", node.getId());
// int handleCount = node.getData().getHandles().size();
// Assert.isTrue(targets.size() <= handleCount, "开始节点 {} 的目标连接数超过 handles 数量", node.getId());
return ValidatorResult.valid();
}