Spring AI

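Overview

Spring AI is the artificial intelligence framework of the Spring ecosystem, designed to simplify the development of AI applications. It integrates with a range of AI services, including large language models, vector databases, and image generation, so developers can add AI capabilities to Spring applications with little boilerplate.

Core Features

1. Unified AI abstraction

  • Provides a consistent API
  • Supports multiple AI service providers
  • Simplifies switching between and configuring AI services

2. Multimodal support

  • Text generation and understanding
  • Image generation and recognition
  • Speech processing
  • Vector embeddings

3. Spring ecosystem integration

  • Integrates seamlessly with Spring Boot
  • Works with Spring Security
  • Integrates with Spring Data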
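
The idea behind a unified abstraction can be sketched in plain Java. This is not Spring AI's actual API: `TextModel`, `EchoModel`, `UpperCaseModel`, and `Assistant` are hypothetical names, used only to show how application code can depend on one interface while the provider behind it is swapped.

```java
import java.util.Locale;

// Hypothetical provider-neutral contract; Spring AI's real interfaces differ.
interface TextModel {
    String generate(String prompt);
}

// Stand-in for one provider: prefixes the prompt and returns it.
final class EchoModel implements TextModel {
    @Override
    public String generate(String prompt) {
        return "echo: " + prompt;
    }
}

// Stand-in for a second provider: upper-cases the prompt.
final class UpperCaseModel implements TextModel {
    @Override
    public String generate(String prompt) {
        return prompt.toUpperCase(Locale.ROOT);
    }
}

// Application code depends only on the interface, so the provider can be
// replaced without touching this class.
final class Assistant {
    private final TextModel model;

    Assistant(TextModel model) {
        this.model = model;
    }

    String ask(String question) {
        return model.generate(question);
    }
}
```

Constructing `new Assistant(new EchoModel())` or `new Assistant(new UpperCaseModel())` changes the behavior without changing `Assistant`, which is the property the feature list describes.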

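Supported AI Services

OpenAI Integration

Configuration

yaml
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}   # read from the environment, never hard-coded
      base-url: https://api.openai.com
      chat:
        options:                   # model options are nested under chat.options
          model: gpt-3.5-turbo
          temperature: 0.7
          max-tokens: 1000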

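Using the Chat Model

java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

@RestController
public class ChatController {

    private final ChatClient chatClient;

    // Spring Boot auto-configures a ChatClient.Builder for the configured model.
    public ChatController(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    // Blocking call: returns the complete reply as a single string.
    @PostMapping("/chat")
    public String chat(@RequestBody String message) {
        return chatClient.prompt()
            .user(message)
            .call()
            .content();
    }

    // Streaming call: emits the reply in chunks as the model produces them.
    @PostMapping("/chat/stream")
    public Flux<String> chatStream(@RequestBody String message) {
        return chatClient.prompt()
            .user(message)
            .stream()
            .content();
    }
}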

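Structured Output

java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

// Target type; the model's reply is mapped onto these fields.
public record PersonInfo(String name, int age, String occupation) {}

@Service
public class PersonExtractionService {

    private final ChatClient chatClient;

    public PersonExtractionService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    public PersonInfo extractPersonInfo(String text) {
        // entity(...) asks for output matching the record and converts it.
        return chatClient.prompt()
            .user("Extract person information from: " + text)
            .call()
            .entity(PersonInfo.class);
    }
}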

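Vector Database Integration

Configuring the Vector Store

yaml
# Two stores are shown to illustrate their properties; an application normally
# configures only one, otherwise injecting a single VectorStore is ambiguous.
spring:
  ai:
    vectorstore:
      chroma:
        client:                    # Chroma connection settings sit under "client"
          host: http://localhost
          port: 8000
      pinecone:
        api-key: ${PINECONE_API_KEY}
        environment: us-west1-gcp
        index-name: my-index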

Using the Vector Store

java
import java.util.List;
import java.util.Map;

import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

@Service
public class DocumentService {

    private final VectorStore vectorStore;

    public DocumentService(VectorStore vectorStore) {
        this.vectorStore = vectorStore;
    }

    public void addDocument(String content, Map<String, Object> metadata) {
        // The vector store computes the embedding with its configured embedding
        // model when a document is added, so no manual embedding step is needed.
        vectorStore.add(List.of(new Document(content, metadata)));
    }

    public List<Document> searchSimilar(String query, int topK) {
        // The store embeds the query text as well. This is the 1.0 milestone API;
        // Spring AI 1.0 GA spells it SearchRequest.builder().query(query)
        // .topK(topK).similarityThreshold(0.7).build().
        return vectorStore.similaritySearch(
            SearchRequest.query(query)
                .withTopK(topK)
                .withSimilarityThreshold(0.7)
        );
    }
}
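
The `0.7` similarity threshold is easier to reason about with the underlying math in view. The sketch below assumes the store ranks results by cosine similarity; many stores do, but the metric depends on the store and its configuration. `CosineSimilarity` is a hypothetical helper, not part of Spring AI.

```java
// Hypothetical helper illustrating cosine similarity between two embeddings.
final class CosineSimilarity {

    private CosineSimilarity() {
    }

    // Returns a value in [-1, 1]; 1 means the vectors point the same way.
    static double between(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same dimension");
        }
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            return 0.0; // a zero vector has no direction; treat it as dissimilar
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // True when the pair would survive a threshold filter.
    static boolean passes(float[] a, float[] b, double threshold) {
        return between(a, b) >= threshold;
    }
}
```

Under this assumption, a document is returned only if its embedding scores at least 0.7 against the query embedding, so raising the threshold trades recall for precision.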

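RAG (Retrieval-Augmented Generation)

RAG Implementation

java
import java.util.List;
import java.util.stream.Collectors;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

@Service
public class RagService {

    private final ChatClient chatClient;
    private final VectorStore vectorStore;

    public RagService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClient = chatClientBuilder.build();
        this.vectorStore = vectorStore;
    }

    public String answerQuestion(String question) {
        // 1. Retrieve the most relevant documents
        List<Document> relevantDocs = vectorStore.similaritySearch(
            SearchRequest.query(question).withTopK(3)
        );

        // 2. Build the context (Spring AI 1.0 GA renames getContent() to getText())
        String context = relevantDocs.stream()
            .map(Document::getContent)
            .collect(Collectors.joining("\n\n"));

        // 3. Generate the answer from the context and the question
        String prompt = """
            Based on the following context, answer the question.

            Context:
            %s

            Question: %s

            Answer:
            """.formatted(context, question);

        return chatClient.prompt()
            .user(prompt)
            .call()
            .content();
    }
}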

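Image Generation

DALL-E Integration

java
import java.util.List;

import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.stereotype.Service;

@Service
public class ImageGenerationService {

    // ImageModel is the current interface name; early releases called it ImageClient.
    private final ImageModel imageModel;

    public ImageGenerationService(ImageModel imageModel) {
        this.imageModel = imageModel;
    }

    public String generateImage(String prompt) {
        // ImagePrompt takes its options through the constructor. The builder
        // methods are named withModel(...), withWidth(...) in the 1.0 milestones
        // and model(...), width(...) in 1.0 GA.
        ImageResponse response = imageModel.call(
            new ImagePrompt(prompt,
                OpenAiImageOptions.builder()
                    .withModel("dall-e-3")
                    .withWidth(1024)
                    .withHeight(1024)
                    .build()));

        return response.getResult().getOutput().getUrl();
    }

    public List<String> generateMultipleImages(String prompt, int count) {
        // dall-e-3 accepts only n = 1, so dall-e-2 is used for several images.
        ImageResponse response = imageModel.call(
            new ImagePrompt(prompt,
                OpenAiImageOptions.builder()
                    .withModel("dall-e-2")
                    .withN(count)
                    .build()));

        return response.getResults().stream()
            .map(result -> result.getOutput().getUrl())
            .toList();
    }
}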

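Advanced Features

Function Calling

Defining a Function

java
import java.util.function.Function;

import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;

@Component
@Description("Get the current weather for a location") // shown to the model (assumed lookup)
public class WeatherFunction implements Function<WeatherRequest, WeatherResponse> {
    // WeatherService is the application's own weather client (assumed, not shown).
    private final WeatherService weatherService;

    public WeatherFunction(WeatherService weatherService) {
        this.weatherService = weatherService;
    }

    @Override
    public WeatherResponse apply(WeatherRequest request) {
        // Call the weather API; records expose location(), not getLocation().
        return weatherService.getWeather(request.location());
    }
}

public record WeatherRequest(String location) {}
public record WeatherResponse(String location, double temperature, String description) {}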

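Using Function Calling

java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

@Service
public class AssistantService {

    private final ChatClient chatClient;

    // Injecting WeatherFunction guarantees the bean exists before it is referenced.
    public AssistantService(ChatClient.Builder chatClientBuilder, WeatherFunction weatherFunction) {
        // "weatherFunction" is the bean name of the WeatherFunction component.
        // defaultFunctions(...) is the milestone-era API; later Spring AI releases
        // replace it with a tools API, so check the version in use.
        this.chatClient = chatClientBuilder
            .defaultFunctions("weatherFunction")
            .build();
    }

    public String handleUserQuery(String query) {
        // The model decides whether to call the registered function.
        return chatClient.prompt()
            .user(query)
            .call()
            .content();
    }
}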

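Prompt Templates

Template Definition

java
import java.util.Map;

import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Component;

@Component
public class PromptTemplates {

    // {language}, {maxWords}, and {text} are placeholders filled in at render time.
    private final PromptTemplate summaryTemplate = new PromptTemplate("""
        Please summarize the following text in {language}.
        Keep the summary under {maxWords} words.

        Text to summarize:
        {text}

        Summary:
        """);

    public Prompt createSummaryPrompt(String text, String language, int maxWords) {
        Map<String, Object> variables = Map.of(
            "text", text,
            "language", language,
            "maxWords", maxWords
        );

        return summaryTemplate.create(variables);
    }
}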
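
What placeholder substitution does can be shown without any framework. The sketch below is illustrative only: `SimpleTemplate` is a hypothetical class, and Spring AI's `PromptTemplate` delegates to a real template engine rather than to this regex.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical minimal renderer for {name}-style placeholders.
final class SimpleTemplate {

    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    private SimpleTemplate() {
    }

    // Replaces every {name} with the matching variable; fails fast on a missing one.
    static String render(String template, Map<String, Object> variables) {
        Matcher matcher = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (matcher.find()) {
            String name = matcher.group(1);
            if (!variables.containsKey(name)) {
                throw new IllegalArgumentException("Missing template variable: " + name);
            }
            // quoteReplacement keeps '$' and '\' in values from being read as group references
            String value = String.valueOf(variables.get(name));
            matcher.appendReplacement(out, Matcher.quoteReplacement(value));
        }
        matcher.appendTail(out);
        return out.toString();
    }
}
```

Failing on a missing variable is a deliberate choice in this sketch: a prompt sent with an unfilled `{text}` placeholder tends to produce confusing model output that is harder to trace than an exception.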

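Conversation Memory

Session Management

java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Service;

@Service
public class ConversationService {

    private static final int MAX_HISTORY = 20;

    private final ChatClient chatClient;
    private final Map<String, List<Message>> conversations = new ConcurrentHashMap<>();

    public ConversationService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    public String chat(String sessionId, String userMessage) {
        List<Message> history = conversations.computeIfAbsent(sessionId, k -> new ArrayList<>());

        // ArrayList is not thread-safe, so concurrent requests for one session are serialized.
        synchronized (history) {
            // Add the user message
            history.add(new UserMessage(userMessage));

            // Generate the reply from the full history
            String response = chatClient.prompt()
                .messages(history)
                .call()
                .content();

            // Add the assistant reply
            history.add(new AssistantMessage(response));

            // Keep only the most recent messages
            if (history.size() > MAX_HISTORY) {
                history.subList(0, history.size() - MAX_HISTORY).clear();
            }

            return response;
        }
    }

    public void clearConversation(String sessionId) {
        conversations.remove(sessionId);
    }
}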
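
The history-trimming step can also be isolated into a small sliding-window buffer. The sketch below is one possible shape under stated assumptions: `BoundedHistory` is a hypothetical class, it counts messages rather than tokens, and it synchronizes its own methods so callers need no extra locking.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical sliding-window buffer that keeps only the newest items.
final class BoundedHistory<T> {

    private final int capacity;
    private final Deque<T> items = new ArrayDeque<>();

    BoundedHistory(int capacity) {
        if (capacity <= 0) {
            throw new IllegalArgumentException("capacity must be positive");
        }
        this.capacity = capacity;
    }

    // Appends an item and drops the oldest ones once the window is full.
    synchronized void add(T item) {
        items.addLast(item);
        while (items.size() > capacity) {
            items.removeFirst();
        }
    }

    // Returns a copy, so callers can iterate without holding the lock.
    synchronized List<T> snapshot() {
        return new ArrayList<>(items);
    }
}
```

A message-count window is simple but only approximates the real constraint, which is the model's token limit; a production version would likely trim by estimated token count instead.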

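Practical Use Cases

Intelligent Customer Service System

java
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

// Request and response shapes are not shown in the original; these records are assumptions.
record ChatRequest(String sessionId, String message) {}
record ChatResponse(String answer, String sessionId) {}

@RestController
public class CustomerServiceController {

    private final RagService ragService;
    private final ConversationService conversationService;

    public CustomerServiceController(RagService ragService, ConversationService conversationService) {
        this.ragService = ragService;
        this.conversationService = conversationService;
    }

    @PostMapping("/customer-service/chat")
    public ChatResponse chat(@RequestBody ChatRequest request) {
        String sessionId = request.sessionId();
        String message = request.message();

        // First try to answer from the knowledge base
        String ragAnswer = ragService.answerQuestion(message);

        // If RAG found no suitable answer, fall back to the conversational model
        String finalAnswer;
        if (isGoodAnswer(ragAnswer)) {
            finalAnswer = ragAnswer;
        } else {
            finalAnswer = conversationService.chat(sessionId, message);
        }

        return new ChatResponse(finalAnswer, sessionId);
    }

    private boolean isGoodAnswer(String answer) {
        // Simple quality heuristic. It only works if the RAG prompt tells the model
        // to reply "I don't know" when the context is insufficient.
        return answer != null && !answer.contains("I don't know");
    }
}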

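Document Analysis System

java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

// Result types are not shown in the original; these record shapes are assumptions.
record KeyInformation(String title, List<String> keywords) {}
record DocumentAnalysisResult(String summary, KeyInformation keyInformation) {}

@Service
public class DocumentAnalysisService {

    private final ChatClient chatClient;
    private final VectorStore vectorStore;

    public DocumentAnalysisService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClient = chatClientBuilder.build();
        this.vectorStore = vectorStore;
    }

    public DocumentAnalysisResult analyzeDocument(MultipartFile file) throws IOException {
        // 1. Extract the document content
        String content = extractContent(file);
        // 2. Split it into chunks
        List<String> chunks = chunkDocument(content);
        // 3. Generate a summary
        String summary = generateSummary(content);
        // 4. Extract key information
        KeyInformation keyInfo = extractKeyInformation(content);
        // 5. Store the chunks in the vector database
        storeInVectorDatabase(chunks, file.getOriginalFilename());
        return new DocumentAnalysisResult(summary, keyInfo);
    }

    // Assumption: the upload is UTF-8 plain text; PDF or Office files need a real parser.
    private String extractContent(MultipartFile file) throws IOException {
        return new String(file.getBytes(), StandardCharsets.UTF_8);
    }

    // Assumption: naive chunking that splits on blank lines (one chunk per paragraph).
    private List<String> chunkDocument(String content) {
        return Arrays.stream(content.split("\\R\\s*\\R"))
            .map(String::strip)
            .filter(chunk -> !chunk.isEmpty())
            .toList();
    }

    private void storeInVectorDatabase(List<String> chunks, String filename) {
        String source = filename != null ? filename : "unknown";
        vectorStore.add(chunks.stream()
            .map(chunk -> new Document(chunk, Map.<String, Object>of("source", source)))
            .toList());
    }

    private String generateSummary(String content) {
        return chatClient.prompt()
            .user("Please provide a concise summary of the following document:\n\n" + content)
            .call()
            .content();
    }

    private KeyInformation extractKeyInformation(String content) {
        return chatClient.prompt()
            .user("Extract key information from this document: " + content)
            .call()
            .entity(KeyInformation.class);
    }
}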
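
The chunking step deserves a closer look, since chunk boundaries affect retrieval quality. The sketch below shows one common approach, fixed-size chunks with overlap, under stated assumptions: `TextChunker` is a hypothetical helper, and sizes are measured in characters rather than tokens. Spring AI also ships its own text splitters, such as `TokenTextSplitter`, which may be the better choice in a real application.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical fixed-size chunker; sizes are in characters, not tokens.
final class TextChunker {

    private TextChunker() {
    }

    // Splits text into chunks of at most chunkSize characters, where each chunk
    // repeats the last `overlap` characters of the previous one.
    static List<String> chunk(String text, int chunkSize, int overlap) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive");
        }
        if (overlap < 0 || overlap >= chunkSize) {
            throw new IllegalArgumentException("overlap must be in [0, chunkSize)");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap; // how far the window advances each time
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break; // the final chunk already reaches the end of the text
            }
        }
        return chunks;
    }
}
```

For example, `TextChunker.chunk("abcdefghij", 4, 1)` yields `abcd`, `defg`, and `ghij`. The overlap keeps a sentence that straddles a boundary at least partly present in both neighboring chunks.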

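Configuration and Optimization

Performance Tuning

yaml
spring:
  ai:
    openai:
      chat:
        options:
          temperature: 0.7
          max-tokens: 1000
          # timeout: 30s   # probably not a recognized chat option; request timeouts
          #                # are normally set on the underlying HTTP client
    # Retry settings live under spring.ai.retry rather than under the provider
    retry:
      max-attempts: 3
      backoff:
        initial-interval: 1s
        multiplier: 2
        max-interval: 10s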
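
To see what these retry settings imply in practice, the sketch below computes the waiting time before each retry under exponential backoff. It assumes `max-attempts` counts the initial call (so 3 attempts mean at most 2 waits) and ignores jitter; `BackoffSchedule` is a hypothetical helper, not a Spring class.

```java
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper that lists the delay before each retry.
final class BackoffSchedule {

    private BackoffSchedule() {
    }

    // maxAttempts includes the first call, so there are maxAttempts - 1 delays.
    static List<Duration> delays(int maxAttempts, Duration initial, double multiplier, Duration max) {
        List<Duration> result = new ArrayList<>();
        long maxMs = max.toMillis();
        long delayMs = Math.min(initial.toMillis(), maxMs);
        for (int retry = 1; retry < maxAttempts; retry++) {
            result.add(Duration.ofMillis(delayMs));
            // grow the delay geometrically, but never beyond the cap
            delayMs = (long) Math.min(delayMs * multiplier, (double) maxMs);
        }
        return result;
    }
}
```

With the values from the configuration (3 attempts, 1s initial interval, multiplier 2, 10s cap) this gives waits of 1s and 2s. The cap only matters with more attempts: 6 attempts would wait 1s, 2s, 4s, 8s, and then 10s.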

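Security Configuration

java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.web.SecurityFilterChain;

@Configuration
@EnableWebSecurity
public class AISecurityConfig {

    @Bean
    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
        // hasRole("AI_USER") checks for the authority ROLE_AI_USER. With default JWT
        // handling, authorities come from scopes (SCOPE_*), so a custom
        // JwtAuthenticationConverter is usually needed to map roles.
        return http
            .authorizeHttpRequests(auth -> auth
                .requestMatchers("/ai/**").hasRole("AI_USER")
                .anyRequest().authenticated()
            )
            .oauth2ResourceServer(oauth2 -> oauth2.jwt(Customizer.withDefaults()))
            .build();
    }
}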

Monitoring and Logging

java
import java.time.Duration;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import org.springframework.stereotype.Component;

@Component
public class AIMetrics {

    private final MeterRegistry meterRegistry;

    public AIMetrics(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
    }

    public void recordRequest(String model, String operation) {
        // Tags are part of a meter's identity, so they are set when the meter is
        // built; register(...) returns the existing meter for a known name and tag set.
        Counter.builder("ai.requests.total")
            .description("Total AI requests")
            .tags("model", model, "operation", operation)
            .register(meterRegistry)
            .increment();
    }

    public void recordResponseTime(Duration duration, String model) {
        Timer.builder("ai.response.time")
            .description("AI response time")
            .tags("model", model)
            .register(meterRegistry)
            .record(duration);
    }
}

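Best Practices

  1. Prompt engineering: design clear, specific prompts
  2. Error handling: handle API rate limits, network errors, and other failures
  3. Cost control: monitor the number of API calls and token usage
  4. Data privacy: make sure sensitive data is handled securely
  5. Caching strategy: cache the results of common queries
  6. Model selection: choose a model that matches the complexity of the task
  7. Performance monitoring: track response times and success rates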
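
The caching practice can be sketched with a small in-memory LRU cache keyed by the prompt text. This is a minimal illustration under stated assumptions: `ResponseCache` is a hypothetical class, identical prompts are assumed to deserve identical answers, and entries never expire.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical LRU cache for model responses, keyed by the exact prompt text.
final class ResponseCache {

    private final Map<String, String> entries;

    ResponseCache(int maxEntries) {
        // accessOrder = true keeps the least recently used entry first, and
        // removeEldestEntry evicts it once the cache grows past maxEntries.
        this.entries = new LinkedHashMap<String, String>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String, String> eldest) {
                return size() > maxEntries;
            }
        };
    }

    // Returns the cached answer, or calls the model once and stores the result.
    synchronized String get(String prompt, Function<String, String> model) {
        String cached = entries.get(prompt);
        if (cached != null) {
            return cached;
        }
        String answer = model.apply(prompt);
        entries.put(prompt, answer);
        return answer;
    }

    synchronized int size() {
        return entries.size();
    }
}
```

A caller would wrap the model call, for example `cache.get(question, q -> chatClient.prompt().user(q).call().content())`. Holding the lock during the model call keeps the sketch short but serializes all requests; a real deployment would more likely use a dedicated caching library with expiry.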

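Summary

Spring AI gives Java developers a capable, approachable framework for AI integration. Through its unified abstraction layer, developers can integrate a range of AI services into Spring applications and build intelligent applications. As AI technology advances rapidly, Spring AI will continue to evolve and offer more support for enterprise AI application development.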