mirror of
				https://github.com/YunaiV/ruoyi-vue-pro.git
				synced 2025-10-31 18:49:06 +08:00 
			
		
		
		
	【增加】对接 Midjourney,增加nonce传递,更新Midjourney image 状态
This commit is contained in:
		| @ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult; | |||||||
| import cn.iocoder.yudao.module.ai.service.AiImageService; | import cn.iocoder.yudao.module.ai.service.AiImageService; | ||||||
| import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; | import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; | ||||||
| import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq; | import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq; | ||||||
| import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes; |  | ||||||
| import io.swagger.v3.oas.annotations.Operation; | import io.swagger.v3.oas.annotations.Operation; | ||||||
| import io.swagger.v3.oas.annotations.tags.Tag; | import io.swagger.v3.oas.annotations.tags.Tag; | ||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
| @ -42,7 +41,8 @@ public class AiImageController { | |||||||
|  |  | ||||||
|     @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") |     @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") | ||||||
|     @PostMapping("/midjourney") |     @PostMapping("/midjourney") | ||||||
|     public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) { |     public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) { | ||||||
|         return CommonResult.success(aiImageService.midjourney(req)); |         aiImageService.midjourney(req); | ||||||
|  |         return CommonResult.success(null); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -28,5 +28,5 @@ public interface AiImageService { | |||||||
|      * @param req |      * @param req | ||||||
|      * @return |      * @return | ||||||
|      */ |      */ | ||||||
|     AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req); |     void midjourney(AiImageMidjourneyReq req); | ||||||
| } | } | ||||||
|  | |||||||
| @ -95,18 +95,15 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|     @Transactional(rollbackFor = Exception.class) |     @Transactional(rollbackFor = Exception.class) | ||||||
|     public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) { |     public void midjourney(AiImageMidjourneyReq req) { | ||||||
|         // 保存数据库 |         // 保存数据库 | ||||||
|         doSave(req.getPrompt(), null, "midjoureny", |         AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny", | ||||||
|                 null, AiChatDrawingStatusEnum.SUBMIT, null); |                 null, AiChatDrawingStatusEnum.SUBMIT, null); | ||||||
|         // 提交 midjourney 任务 |         // 提交 midjourney 任务 | ||||||
|         Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt()); |         Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt()); | ||||||
|         if (!imagine) { |         if (!imagine) { | ||||||
|             throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL); |             throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL); | ||||||
|         } |         } | ||||||
|         // |  | ||||||
|  |  | ||||||
|         return null; |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { |     private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { | ||||||
| @ -120,7 +117,7 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private void doSave(String prompt, |     private AiImageDO doSave(String prompt, | ||||||
|                         String size, |                         String size, | ||||||
|                         String model, |                         String model, | ||||||
|                         String imageUrl, |                         String imageUrl, | ||||||
| @ -138,5 +135,6 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|         aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus()); |         aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus()); | ||||||
|         aiImageDO.setDrawingError(drawingError); |         aiImageDO.setDrawingError(drawingError); | ||||||
|         aiImageMapper.insert(aiImageDO); |         aiImageMapper.insert(aiImageDO); | ||||||
|  |         return aiImageDO; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,7 +1,15 @@ | |||||||
| package cn.iocoder.yudao.module.ai.service.midjourneyHandler; | package cn.iocoder.yudao.module.ai.service.midjourneyHandler; | ||||||
|  |  | ||||||
|  | import cn.hutool.core.collection.CollUtil; | ||||||
|  | import cn.hutool.core.util.StrUtil; | ||||||
| import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage; | import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage; | ||||||
|  | import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum; | ||||||
| import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler; | import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler; | ||||||
|  | import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO; | ||||||
|  | import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; | ||||||
|  | import cn.iocoder.yudao.module.ai.mapper.AiImageMapper; | ||||||
|  | import com.alibaba.fastjson2.JSON; | ||||||
|  | import lombok.AllArgsConstructor; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.springframework.stereotype.Component; | import org.springframework.stereotype.Component; | ||||||
|  |  | ||||||
| @ -14,10 +22,51 @@ import org.springframework.stereotype.Component; | |||||||
|  */ |  */ | ||||||
| @Component | @Component | ||||||
| @Slf4j | @Slf4j | ||||||
|  | @AllArgsConstructor | ||||||
| public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler { | public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler { | ||||||
|  |  | ||||||
|  |     private final AiImageMapper aiImageMapper; | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|     public void messageHandler(MidjourneyMessage midjourneyMessage) { |     public void messageHandler(MidjourneyMessage midjourneyMessage) { | ||||||
|         log.info("yudao-midjourney-midjourney-message-handler", midjourneyMessage); |         log.info("yudao-midjourney-midjourney-message-handler {}", JSON.toJSONString(midjourneyMessage)); | ||||||
|  |         if (midjourneyMessage.getContent() != null) { | ||||||
|  |             log.info("进度id {} 状态 {} 进度 {}", | ||||||
|  |                     midjourneyMessage.getNonce(), | ||||||
|  |                     midjourneyMessage.getGenerateStatus(), | ||||||
|  |                     midjourneyMessage.getContent().getProgress()); | ||||||
|  |         } | ||||||
|  |         // | ||||||
|  |         updateImage(midjourneyMessage); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private void updateImage(MidjourneyMessage midjourneyMessage) { | ||||||
|  |         // Nonce 不存在不更新 | ||||||
|  |         if (StrUtil.isBlank(midjourneyMessage.getNonce())) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |         // 获取id | ||||||
|  |         Long aiImageId = Long.valueOf(midjourneyMessage.getNonce()); | ||||||
|  |         // 获取生成 url | ||||||
|  |         String imageUrl = null; | ||||||
|  |         if (CollUtil.isNotEmpty(midjourneyMessage.getAttachments())) { | ||||||
|  |             imageUrl = midjourneyMessage.getAttachments().get(0).getUrl(); | ||||||
|  |         } | ||||||
|  |         // 转换状态 | ||||||
|  |         AiChatDrawingStatusEnum drawingStatusEnum = null; | ||||||
|  |         String generateStatus = midjourneyMessage.getGenerateStatus(); | ||||||
|  |         if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) { | ||||||
|  |             drawingStatusEnum = AiChatDrawingStatusEnum.COMPLETE; | ||||||
|  |         } else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) { | ||||||
|  |             drawingStatusEnum = AiChatDrawingStatusEnum.IN_PROGRESS; | ||||||
|  |         }  else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) { | ||||||
|  |             drawingStatusEnum = AiChatDrawingStatusEnum.WAITING; | ||||||
|  |         } | ||||||
|  |         aiImageMapper.updateById( | ||||||
|  |                 new AiImageDO() | ||||||
|  |                         .setId(aiImageId) | ||||||
|  |                         .setDrawingImageUrl(imageUrl) | ||||||
|  |                         .setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus()) | ||||||
|  |         ); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -14,6 +14,10 @@ public class MidjourneyMessage { | |||||||
| 	 * id是一个重要的字段,在同时生成多个的时候,可以区分生成信息 | 	 * id是一个重要的字段,在同时生成多个的时候,可以区分生成信息 | ||||||
| 	 */ | 	 */ | ||||||
| 	private String id; | 	private String id; | ||||||
|  | 	/** | ||||||
|  | 	 * 提交id(nonce 可能会不存在,系统提示的时候,这个为空) | ||||||
|  | 	 */ | ||||||
|  | 	private String nonce; | ||||||
| 	/** | 	/** | ||||||
| 	 * 现在已知: | 	 * 现在已知: | ||||||
| 	 * 0:我们发送的消息,和指令 | 	 * 0:我们发送的消息,和指令 | ||||||
| @ -45,6 +49,14 @@ public class MidjourneyMessage { | |||||||
| 	 * {@link MidjourneyGennerateStatusEnum} | 	 * {@link MidjourneyGennerateStatusEnum} | ||||||
| 	 */ | 	 */ | ||||||
| 	private String generateStatus; | 	private String generateStatus; | ||||||
|  | 	/** | ||||||
|  | 	 * 一般用于提示信息 | ||||||
|  | 	 * - 错误 | ||||||
|  | 	 * - 并发队列满了 | ||||||
|  | 	 * - 账号违规了、敏感词 | ||||||
|  | 	 * - 账号被封 | ||||||
|  | 	 */ | ||||||
|  | 	private List<Embed> embeds; | ||||||
|  |  | ||||||
| 	@Data | 	@Data | ||||||
| 	@Accessors(chain = true) | 	@Accessors(chain = true) | ||||||
| @ -123,4 +135,39 @@ public class MidjourneyMessage { | |||||||
| 		private String progress; | 		private String progress; | ||||||
| 		private String status; | 		private String status; | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	/** | ||||||
|  | 	 * embed 用于警告、提示、错误 | ||||||
|  | 	 */ | ||||||
|  | 	@Data | ||||||
|  | 	@Accessors(chain = true) | ||||||
|  | 	public static class Embed { | ||||||
|  |  | ||||||
|  | 		// 内容扫描版本号 | ||||||
|  | 		private int contentScanVersion; | ||||||
|  |  | ||||||
|  | 		// 颜色值,这里用Java的Color类来表示,注意实际使用中可能需要自定义方法来从int转换为Color对象 | ||||||
|  | 		private String color; | ||||||
|  |  | ||||||
|  | 		// 页脚信息,包含文本 | ||||||
|  | 		private Footer footer; | ||||||
|  |  | ||||||
|  | 		// 描述信息 | ||||||
|  | 		private String description; | ||||||
|  |  | ||||||
|  | 		// 消息类型,这里是富文本类型(这个区分不同提示类型) | ||||||
|  | 		private String type; | ||||||
|  |  | ||||||
|  | 		// 标题 | ||||||
|  | 		private String title; | ||||||
|  |  | ||||||
|  | 		// Footer类,作为嵌套类存在,用来表示footer部分的JSON对象 | ||||||
|  | 		@Data | ||||||
|  | 		@Accessors(chain = true) | ||||||
|  | 		public static class Footer { | ||||||
|  | 			// 页脚文本 | ||||||
|  | 			private String text; | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -38,11 +38,13 @@ public class MidjourneyInteractionsApi extends MidjourneyInteractions { | |||||||
|         this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions()); |         this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     public Boolean imagine(String prompt) { |     public Boolean imagine(Long id, String prompt) { | ||||||
|  |         String nonce = String.valueOf(id); | ||||||
|         // 获取请求模板 |         // 获取请求模板 | ||||||
|         String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine"); |         String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine"); | ||||||
|         // 设置参数 |         // 设置参数 | ||||||
|         HashMap<String, String> requestParams = getDefaultParams(); |         HashMap<String, String> requestParams = getDefaultParams(); | ||||||
|  |         requestParams.put("nonce", nonce); | ||||||
|         requestParams.put("prompt", prompt); |         requestParams.put("prompt", prompt); | ||||||
|         // 解析 template 参数占位符 |         // 解析 template 参数占位符 | ||||||
|         String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); |         String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); | ||||||
|  | |||||||
| @ -6,6 +6,10 @@ public final class MidjourneyConstants { | |||||||
| 	 * 消息 - 编号 | 	 * 消息 - 编号 | ||||||
| 	 */ | 	 */ | ||||||
| 	public static final String MSG_ID = "id"; | 	public static final String MSG_ID = "id"; | ||||||
|  | 	/** | ||||||
|  | 	 * 用于区分操作唯一性 | ||||||
|  | 	 */ | ||||||
|  | 	public static final String MSG_NONCE = "nonce"; | ||||||
| 	/** | 	/** | ||||||
| 	 * 消息 - 类型 | 	 * 消息 - 类型 | ||||||
| 	 * 现在已知: | 	 * 现在已知: | ||||||
| @ -32,6 +36,10 @@ public final class MidjourneyConstants { | |||||||
| 	 * 附件(生成中比较模糊的图片) | 	 * 附件(生成中比较模糊的图片) | ||||||
| 	 */ | 	 */ | ||||||
| 	public static final String MSG_ATTACHMENTS = "attachments"; | 	public static final String MSG_ATTACHMENTS = "attachments"; | ||||||
|  | 	/** | ||||||
|  | 	 * 一般用于提示 | ||||||
|  | 	 */ | ||||||
|  | 	public static final String MSG_EMBEDS = "embeds"; | ||||||
|  |  | ||||||
|  |  | ||||||
| 	// | 	// | ||||||
|  | |||||||
| @ -42,12 +42,14 @@ public class MidjourneyMessageListener { | |||||||
|         if (ignoreAndLogMessage(data, messageType)) { |         if (ignoreAndLogMessage(data, messageType)) { | ||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
|  |         log.info("socket message: {}", raw); | ||||||
|         // 转换几个重要的信息 |         // 转换几个重要的信息 | ||||||
|         MidjourneyMessage mjMessage = new MidjourneyMessage(); |         MidjourneyMessage mjMessage = new MidjourneyMessage(); | ||||||
|         mjMessage.setId(data.getString(MidjourneyConstants.MSG_ID)); |         mjMessage.setId(getString(data, MidjourneyConstants.MSG_ID, "")); | ||||||
|  |         mjMessage.setNonce(getString(data, MidjourneyConstants.MSG_NONCE, "")); | ||||||
|         mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE)); |         mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE)); | ||||||
|         mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8")); |         mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8")); | ||||||
| 		mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT))); |         mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT))); | ||||||
|         // 转换 components |         // 转换 components | ||||||
|         if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) { |         if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) { | ||||||
|             String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8"); |             String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8"); | ||||||
| @ -60,6 +62,12 @@ public class MidjourneyMessageListener { | |||||||
|             List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class); |             List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class); | ||||||
|             mjMessage.setAttachments(attachments); |             mjMessage.setAttachments(attachments); | ||||||
|         } |         } | ||||||
|  |         // 转换 embeds 提示信息 | ||||||
|  |         if (!data.getArray(MidjourneyConstants.MSG_EMBEDS).isEmpty()) { | ||||||
|  |             String embedJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_EMBEDS).toJson(), "UTF-8"); | ||||||
|  |             List<MidjourneyMessage.Embed> embeds = JsonUtils.parseArray(embedJson, MidjourneyMessage.Embed.class); | ||||||
|  |             mjMessage.setEmbeds(embeds); | ||||||
|  |         } | ||||||
|         // 转换状态 |         // 转换状态 | ||||||
|         convertGenerateStatus(mjMessage); |         convertGenerateStatus(mjMessage); | ||||||
|         // message handler 调用 |         // message handler 调用 | ||||||
| @ -68,7 +76,20 @@ public class MidjourneyMessageListener { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     private String getString(DataObject data, String key, String defaultValue) { | ||||||
|  |         if (!data.hasKey(key)) { | ||||||
|  |             return defaultValue; | ||||||
|  |         } | ||||||
|  |         return data.getString(key); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     private void convertGenerateStatus(MidjourneyMessage mjMessage) { |     private void convertGenerateStatus(MidjourneyMessage mjMessage) { | ||||||
|  |         // | ||||||
|  |         // tip:提示、警告、异常 content是没有内容的 | ||||||
|  |         // tip: 一般错误信息在 Embeds 只要 Embeds有值,content就没信息。 | ||||||
|  |         if (CollUtil.isNotEmpty(mjMessage.getEmbeds())) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|         if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) { |         if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) { | ||||||
|             mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus()); |             mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus()); | ||||||
|         } else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) { |         } else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) { | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 cherishsince
					cherishsince