dcs 11 小時之前
父節點
當前提交
ea825cca32

+ 28 - 19
virgo.api/src/main/java/com/bosshand/virgo/api/workark/controller/DifyController.java

@@ -1,11 +1,13 @@
 package com.bosshand.virgo.api.workark.controller;
 
+import com.alibaba.fastjson.JSONObject;
 import com.bosshand.virgo.api.workark.model.*;
 import com.bosshand.virgo.api.workark.service.DifyDatasetService;
 import com.bosshand.virgo.api.workark.service.DifyService;
 import com.bosshand.virgo.api.workark.util.SseEmitterUtil;
 import com.bosshand.virgo.core.response.Response;
 import com.bosshand.virgo.exception.Constant;
+import io.github.imfangs.dify.client.exception.DifyApiException;
 import io.swagger.annotations.*;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.web.bind.annotation.*;
@@ -13,7 +15,6 @@ import org.springframework.web.multipart.MultipartFile;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
 import java.io.IOException;
-import java.io.InputStream;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
@@ -212,7 +213,13 @@ public class DifyController {
     @ApiOperation("创建知识库")
     @RequestMapping(value = "/dataset", method = RequestMethod.POST)
     public Response createDataset(@RequestBody DifyDataset difyDataset) {
-        difyDatasetService.createDataset(difyDataset);
+        try {
+            difyDatasetService.createDataset(difyDataset);
+        } catch (IOException e) {
+            return Response.fail(Constant.CODE_BAD_REQUEST, e.getMessage());
+        } catch (DifyApiException e) {
+            return Response.fail(Constant.CODE_BAD_REQUEST, e.getMessage());
+        }
         return Response.ok();
     }
 
@@ -227,12 +234,20 @@ public class DifyController {
     }
 
     @ApiOperation("获取文档列表")
-    @RequestMapping(value = "/datasets/file/{datasetId}", method = RequestMethod.GET)
+    @RequestMapping(value = "/datasets/file/{datasetId}/{page}/{limit}", method = RequestMethod.GET)
     @ApiImplicitParams({
-            @ApiImplicitParam(name = "datasetId", value = "知识库id")
+            @ApiImplicitParam(name = "datasetId", value = "知识库id"),
+            @ApiImplicitParam(name = "page", value = "页码"),
+            @ApiImplicitParam(name = "limit", value = "返回条数"),
     })
-    public Response getFileList(@PathVariable String datasetId) {
-        return Response.ok(difyDatasetService.getDocument(datasetId));
+    public Response getFileList(@PathVariable String datasetId,@PathVariable int page,@PathVariable int limit) {
+        try {
+            return Response.ok(difyDatasetService.getDocument(datasetId, page, limit));
+        } catch (DifyApiException e) {
+            return Response.fail(Constant.CODE_BAD_REQUEST, e.getMessage());
+        } catch (IOException e) {
+            return Response.fail(Constant.CODE_BAD_REQUEST, e.getMessage());
+        }
     }
 
     @ApiOperation("删除文档")
@@ -250,8 +265,12 @@ public class DifyController {
     @ApiImplicitParams({
             @ApiImplicitParam(name = "datasetId", value = "知识库id")
     })
-    public Response retrieveDataset(@PathVariable String datasetId, @RequestBody RetrieveDatasetDto retrieveDatasetDto) {
-        return Response.ok(difyDatasetService.retrieveDataset(datasetId, retrieveDatasetDto));
+    public Response retrieveDataset(@PathVariable String datasetId, @RequestBody JSONObject jsonobject) {
+        try {
+            return Response.ok(difyDatasetService.retrieveDataset(datasetId, jsonobject));
+        } catch (IOException e) {
+            return Response.fail(Constant.CODE_BAD_REQUEST, Constant.RET_INPUT_ERROR);
+        }
     }
 
     private static final List<String> ALLOWED_TYPES = Arrays.asList("txt", "markdown", "md", "mdx", "pdf", "html", "xlsx", "xls", "docx", "csv", "vtt", "properties", "htm");
@@ -268,25 +287,15 @@ public class DifyController {
         if (!ALLOWED_TYPES.contains(suffix)) {
             return Response.fail(Constant.CODE_BAD_REQUEST, "文件格式不支持");
         }
-
         // 验证文件大小(15MB限制)
         if (file.getSize() > 15 * 1024 * 1024) {
             return Response.fail(Constant.CODE_BAD_REQUEST, "文件大小不能超过15MB");
         }
-        InputStream inputStream = null;
         try {
-            inputStream = file.getInputStream();
-            difyDatasetService.createDocumentByFile(inputStream, file.getOriginalFilename(), datasetId);
+            difyDatasetService.uploadFileToDataset(datasetId, file);
             return Response.ok();
         } catch (IOException e) {
             return Response.fail(Constant.CODE_BAD_REQUEST, Constant.RET_INPUT_ERROR);
-        } finally {
-            if (inputStream != null) {
-                try {
-                    inputStream.close();
-                } catch (IOException e) {
-                }
-            }
         }
     }
 

+ 107 - 15
virgo.api/src/main/java/com/bosshand/virgo/api/workark/service/DifyDatasetService.java

@@ -1,20 +1,23 @@
 package com.bosshand.virgo.api.workark.service;
 
+import com.alibaba.fastjson.JSONObject;
 import com.bosshand.virgo.api.workark.dao.DifyDatasetDao;
 import com.bosshand.virgo.api.workark.dao.DifyDatasetDocumentDao;
 import com.bosshand.virgo.api.workark.model.DifyDataset;
 import com.bosshand.virgo.api.workark.model.DifyDatasetDocument;
-import com.bosshand.virgo.api.workark.model.RetrieveDatasetDto;
 import io.github.imfangs.dify.client.DifyClientFactory;
 import io.github.imfangs.dify.client.DifyDatasetsClient;
 import io.github.imfangs.dify.client.exception.DifyApiException;
 import io.github.imfangs.dify.client.model.datasets.*;
+import okhttp3.*;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
+import org.springframework.web.multipart.MultipartFile;
 
 import java.io.IOException;
 import java.io.InputStream;
 import java.util.List;
+import java.util.Objects;
 
 @Service
 public class DifyDatasetService {
@@ -45,9 +48,9 @@ public class DifyDatasetService {
                 getClient().deleteDataset(datasetId);
                 difyDatasetDao.delete(datasetId);
                 difyDatasetDocumentDao.deleteDatasetId(datasetId);
-                System.out.println("删除测试知识库成功,ID: " + datasetId);
+                System.out.println("删除知识库成功,ID: " + datasetId);
             } catch (Exception e) {
-                System.err.println("删除测试知识库失败: " + e.getMessage());
+                System.err.println("删除知识库失败: " + e.getMessage());
             }
         }
     }
@@ -55,7 +58,7 @@ public class DifyDatasetService {
     /**
      * 创建知识库
      */
-    public void createDataset(DifyDataset difyDataset) {
+    public void createDataset(DifyDataset difyDataset) throws IOException, DifyApiException {
         // 创建知识库请求
         CreateDatasetRequest request = CreateDatasetRequest.builder()
                 .name(difyDataset.getName())
@@ -64,15 +67,9 @@ public class DifyDatasetService {
                 .build();
 
         // 发送请求
-        try {
-            DatasetResponse response = getClient().createDataset(request);
-            difyDataset.setDatasetId(response.getId());
-            difyDatasetDao.save(difyDataset);
-        } catch (IOException e) {
-            e.printStackTrace();
-        } catch (DifyApiException e) {
-            e.printStackTrace();
-        }
+        DatasetResponse response = getClient().createDataset(request);
+        difyDataset.setDatasetId(response.getId());
+        difyDatasetDao.save(difyDataset);
     }
 
     /**
@@ -199,10 +196,12 @@ public class DifyDatasetService {
     /**
      * 获取文档列表
      */
-    public List<DifyDatasetDocument> getDocument(String datasetId) {
-        return difyDatasetDocumentDao.getDatasetId(datasetId);
+    public DocumentListResponse getDocument(String datasetId, int page, int limit) throws DifyApiException, IOException {
+        DocumentListResponse response = getClient().getDocuments(datasetId, null, page, limit);
+        return response;
     }
 
+    /**
     public RetrieveResponse retrieveDataset(String datasetId, RetrieveDatasetDto dto) {
 
         String query = dto.getQuery();
@@ -224,6 +223,99 @@ public class DifyDatasetService {
         }
         return null;
     }
+    */
+
+    private static final OkHttpClient client = new OkHttpClient();
+
+    public String uploadFile(String apiUrl, String apiKey, String datasetId, MultipartFile multipartFile) throws IOException {
+
+        String finalUrl = apiUrl.replace("{dataset_id}", datasetId);
+
+        // 构建请求体
+        RequestBody requestBody = new MultipartBody.Builder()
+                .setType(MultipartBody.FORM)
+                // 添加 JSON 数据部分
+                .addFormDataPart(
+                        "data",
+                        null,
+                        RequestBody.create(
+                                MediaType.parse("text/plain"),
+                                "{\"indexing_technique\": \"economy\",\"process_rule\": {\"mode\": \"automatic\"}}"
+                        )
+                )
+                // 添加文件部分
+                .addFormDataPart(
+                        "file",
+                        Objects.requireNonNull(multipartFile.getOriginalFilename()),
+                        RequestBody.create(
+                                MediaType.parse(Objects.requireNonNull(multipartFile.getContentType())),
+                                multipartFile.getBytes()
+                        )
+                )
+                .build();
+
+        // 构建请求
+        Request request = new Request.Builder()
+                .url(finalUrl)
+                .header("Authorization", "Bearer " + apiKey)
+                .post(requestBody)
+                .build();
+
+        // 发送请求并获取响应
+        try (Response response = client.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                throw new IOException("请求失败: " + response.code() + " " + response.message());
+            }
+            return Objects.requireNonNull(response.body()).string();
+        }
+    }
+
+    public void uploadFileToDataset(String datasetId, MultipartFile multipartFile) throws IOException {
+        String apiUrl = "http://203.110.233.149:80/v1/datasets/{dataset_id}/document/create-by-file";
+        String response = uploadFile(apiUrl, "dataset-SWjJp6FOFqT85n7KxxyCFPSS", datasetId, multipartFile);
+        DocumentResponse documentResponse = JSONObject.parseObject(response, DocumentResponse.class);
+        // 保存文档ID
+        String documentId = documentResponse.getDocument().getId();
+        DifyDatasetDocument difyDatasetDocument = new DifyDatasetDocument();
+        difyDatasetDocument.setDatasetId(datasetId);
+        difyDatasetDocument.setDocumentId(documentId);
+        difyDatasetDocument.setName(multipartFile.getOriginalFilename());
+        difyDatasetDocumentDao.save(difyDatasetDocument);
+    }
+
+    public String retrieve(String apiUrl, String apiKey, String datasetId, String json) throws IOException {
+
+        String finalUrl = apiUrl.replace("{dataset_id}", datasetId);
+
+        // 构建请求体
+        RequestBody body = RequestBody.create(MediaType.parse("application/json"), json);
+
+        // 构建请求
+        Request request = new Request.Builder()
+                .url(finalUrl)
+                .header("Authorization", "Bearer " + apiKey)
+                .post(body)
+                .build();
+
+        // 发送请求并获取响应
+        try (Response response = client.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                throw new IOException("请求失败: " + response.code() + " " + response.message());
+            }
+            return Objects.requireNonNull(response.body()).string();
+        }
+
+    }
+
+    public RetrieveResponse retrieveDataset(String datasetId, JSONObject jsonobject) throws IOException {
+        String apiUrl = "http://203.110.233.149:80/v1/datasets/{dataset_id}/retrieve";
+        String json = JSONObject.toJSONString(jsonobject);
+        String response = retrieve(apiUrl, "dataset-SWjJp6FOFqT85n7KxxyCFPSS", datasetId, json);
+        RetrieveResponse retrieveResponse = JSONObject.parseObject(response, RetrieveResponse.class);
+        return retrieveResponse;
+    }
+
+
 
 
 }