Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions apps/tools/serializers/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ def __init__(self, tool: dict, version: str):
}


class NewUUID:
def __init__(self):
self.uuid_dict = {}

def generate_uuid(self, _id):
_id = str(_id)
if _id in self.uuid_dict:
return self.uuid_dict.get(_id)
r = str(uuid.uuid7())
self.uuid_dict[_id] = r
return r


def to_dict(message, file_name):
return {
'line': message.line,
Expand Down Expand Up @@ -414,7 +427,7 @@ def insert(self, instance, with_valid=True):
'user_id': self.data.get('user_id'),
'workspace_id': self.data.get('workspace_id'),
'folder_id': str(instance.get('folder_id', self.data.get('workspace_id'))),
}).import_(name=instance.get('name'))
}).import_(name=instance.get('name'), source='template')

try:
requests.get(template_instance.get('downloadCallbackUrl'), timeout=5)
Expand Down Expand Up @@ -810,23 +823,27 @@ def import_workflow_tools(self, tool, workspace_id, user_id, folder_id, new_chil
"""
if new_child_policy == 0:
tool_list = []
elif new_child_policy == 1:
tool_list = tool.get('tool_list') or []
else:
tool_list = [{**tool, 'id': str(uuid.uuid7())} for tool in tool.get('tool_list') or []]
tool_list = tool.get('tool_list') or []

tool_list = {tool.get('id'): tool for tool in tool_list}.values()
update_tool_map = {}
if len(tool_list) > 0:
new_uuid = NewUUID()
tool_id_list = reduce(lambda x, y: [*x, *y],
[[tool.get('id'), generate_uuid((tool.get('id') + workspace_id or ''))]
[[tool.get('id'), new_uuid.generate_uuid(
tool.get('id')) if new_child_policy == 2 else generate_uuid(
(tool.get('id') + workspace_id or ''))]
for tool
in
tool_list], [])
# 存在的工具列表
exits_tool_id_list = [str(tool.id) for tool in
QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)]
# 需要更新的工具集合
update_tool_map = {tool.get('id'): generate_uuid((tool.get('id') + workspace_id or '')) for tool
update_tool_map = {tool.get('id'): new_uuid.generate_uuid(
tool.get('id')) if new_child_policy == 2 else generate_uuid(
(tool.get('id') + workspace_id or '')) for tool
in
tool_list if
not exits_tool_id_list.__contains__(
Expand All @@ -835,7 +852,9 @@ def import_workflow_tools(self, tool, workspace_id, user_id, folder_id, new_chil
tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if
not exits_tool_id_list.__contains__(
tool.get('id')) and not exits_tool_id_list.__contains__(
generate_uuid((tool.get('id') + workspace_id or '')))]
new_uuid.generate_uuid(
tool.get('id')) if new_child_policy == 2 else generate_uuid(
(tool.get('id') + workspace_id or '')))]

work_flow = self.to_tool_workflow(
tool.get('work_flow'),
Expand Down Expand Up @@ -895,7 +914,7 @@ def update_template_workflow(self, tool_id: str):
return True

@transaction.atomic
def import_(self, scope=ToolScope.WORKSPACE, name=None):
def import_(self, scope=ToolScope.WORKSPACE, name=None, source=None):
self.is_valid()

user_id = self.data.get('user_id')
Expand Down Expand Up @@ -940,7 +959,7 @@ def import_(self, scope=ToolScope.WORKSPACE, name=None):
if tool.get('tool_type') == ToolType.WORKFLOW:
tool['id'] = tool_id
self.import_workflow_tools(tool, workspace_id=self.data.get('workspace_id'), user_id=user_id,
folder_id=folder_id, new_child_policy=1)
folder_id=folder_id, new_child_policy=2 if source == 'template' else 1)
# 自动授权给创建者
UserResourcePermissionSerializer(data={
'workspace_id': self.data.get('workspace_id'),
Expand Down
Loading