Spring AI
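Overview
Spring AI is the Spring ecosystem's framework for building AI applications. It integrates with large language models, vector databases, image generation services, and other AI services, so developers can add AI capabilities to a Spring application with little effort.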
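Core Features
1. Unified AI abstraction
- A consistent API across providers
- Support for multiple AI service providers
- Simpler switching and configuration of AI services
2. Multimodal support
- Text generation and understanding
- Image generation and recognition
- Speech processing
- Vector embeddings
3. Spring ecosystem integration
- Seamless integration with Spring Boot
- Support for Spring Security
- Integration with Spring Data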
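The unified abstraction in item 1 can be pictured as one interface with interchangeable provider implementations. The following is a minimal plain-Java sketch of that idea, not Spring AI's actual types: `TextModel`, `EchoModel`, `UpperCaseModel`, and `ModelRegistry` are hypothetical names invented for illustration, and the two "providers" are stand-ins that need no network access.

```java
import java.util.Locale;
import java.util.Map;

// Hypothetical provider-neutral contract; Spring AI's real interfaces differ.
interface TextModel {
    String generate(String prompt);
}

// Stand-in provider that echoes the prompt back.
final class EchoModel implements TextModel {
    public String generate(String prompt) {
        return "echo: " + prompt;
    }
}

// Stand-in provider that upper-cases the prompt.
final class UpperCaseModel implements TextModel {
    public String generate(String prompt) {
        return prompt.toUpperCase(Locale.ROOT);
    }
}

final class ModelRegistry {
    private final Map<String, TextModel> providers = Map.of(
            "echo", new EchoModel(),
            "upper", new UpperCaseModel());

    // Callers depend only on TextModel, so switching providers is a
    // configuration change rather than a code change.
    TextModel forProvider(String name) {
        TextModel model = providers.get(name);
        if (model == null) {
            throw new IllegalArgumentException("Unknown provider: " + name);
        }
        return model;
    }
}
```

Application code would ask the registry for a model by a configured name and call `generate` without knowing which provider answers.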
Supported AI Services
OpenAI Integration
Configuration
yaml
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
      base-url: https://api.openai.com
      chat:
        # Model settings are bound under chat.options.
        options:
          model: gpt-3.5-turbo
          temperature: 0.7
          max-tokens: 1000
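Using the Chat Model
java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

@RestController
public class ChatController {

    private final ChatClient chatClient;

    // The ChatClient.Builder bean is provided by Spring AI's auto-configuration.
    public ChatController(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    // Blocking call: returns the complete reply as one string.
    @PostMapping("/chat")
    public String chat(@RequestBody String message) {
        return chatClient.prompt()
                .user(message)
                .call()
                .content();
    }

    // Streaming call: emits the reply in chunks as the model produces them.
    @PostMapping("/chat/stream")
    public Flux<String> chatStream(@RequestBody String message) {
        return chatClient.prompt()
                .user(message)
                .stream()
                .content();
    }
}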
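Structured Output
java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

// Target type for the model's answer; in a real project it lives in its own file.
public record PersonInfo(String name, int age, String occupation) {}

@Service
public class PersonExtractionService {

    private final ChatClient chatClient;

    public PersonExtractionService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    public PersonInfo extractPersonInfo(String text) {
        // entity(...) asks the model for output matching the record and maps it to PersonInfo.
        return chatClient.prompt()
                .user("Extract person information from: " + text)
                .call()
                .entity(PersonInfo.class);
    }
}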
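Vector Database Integration
Configuring a Vector Store
yaml
# Configure only the store whose starter is on the classpath.
# Property names vary between Spring AI releases; check the reference for your version.
spring:
  ai:
    vectorstore:
      chroma:
        client:
          host: http://localhost
          port: 8000
      pinecone:
        api-key: ${PINECONE_API_KEY}
        environment: us-west1-gcp
        index-name: my-index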
Using the Vector Store
java
import java.util.List;
import java.util.Map;

import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

@Service
public class DocumentService {

    private final VectorStore vectorStore;

    public DocumentService(VectorStore vectorStore) {
        this.vectorStore = vectorStore;
    }

    public void addDocument(String content, Map<String, Object> metadata) {
        // The VectorStore computes embeddings with its configured embedding
        // model, so no explicit embed() call is needed here.
        vectorStore.add(List.of(new Document(content, metadata)));
    }

    public List<Document> searchSimilar(String query, int topK) {
        // Builder-style request as in Spring AI 1.0; earlier milestones used
        // SearchRequest.query(query).withTopK(topK) instead.
        return vectorStore.similaritySearch(
                SearchRequest.builder()
                        .query(query)
                        .topK(topK)
                        .similarityThreshold(0.7)
                        .build());
    }
}
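To make the `topK` and similarity-threshold parameters concrete, here is a minimal in-memory sketch of what a similarity search does conceptually. It assumes cosine similarity over plain `double[]` vectors; real vector stores may use other metrics and approximate indexes, and `SimilaritySearchSketch` and `Scored` are hypothetical names, not Spring AI types.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class SimilaritySearchSketch {

    record Scored(String text, double score) {}

    // Cosine similarity: dot(a, b) / (|a| * |b|); defined as 0 for a zero vector.
    static double cosine(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Dimension mismatch");
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            return 0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Keep candidates scoring at least `threshold`, best first, at most `topK`.
    static List<Scored> search(double[] query, List<String> texts, List<double[]> vectors,
                               int topK, double threshold) {
        List<Scored> hits = new ArrayList<>();
        for (int i = 0; i < texts.size(); i++) {
            double score = cosine(query, vectors.get(i));
            if (score >= threshold) {
                hits.add(new Scored(texts.get(i), score));
            }
        }
        hits.sort(Comparator.comparingDouble(Scored::score).reversed());
        return List.copyOf(hits.subList(0, Math.min(topK, hits.size())));
    }
}
```

Raising the threshold trades recall for precision, while `topK` bounds how much text is later passed to the model.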
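RAG (Retrieval-Augmented Generation)
Implementing RAG
java
import java.util.List;
import java.util.stream.Collectors;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

@Service
public class RagService {

    private final ChatClient chatClient;
    private final VectorStore vectorStore;

    public RagService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClient = chatClientBuilder.build();
        this.vectorStore = vectorStore;
    }

    public String answerQuestion(String question) {
        // 1. Retrieve the most relevant documents (Spring AI 1.0 builder API).
        List<Document> relevantDocs = vectorStore.similaritySearch(
                SearchRequest.builder().query(question).topK(3).build());

        // 2. Build the context; getText() was named getContent() in earlier milestones.
        String context = relevantDocs.stream()
                .map(Document::getText)
                .collect(Collectors.joining("\n\n"));

        // 3. Generate the answer from the context and the question.
        String prompt = """
                Based on the following context, answer the question.
                Context:
                %s
                Question: %s
                Answer:
                """.formatted(context, question);
        return chatClient.prompt()
                .user(prompt)
                .call()
                .content();
    }
}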
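Joining every retrieved document into the prompt can exceed the model's context window. The sketch below shows one way to cap the context before formatting the same prompt. It is an illustration only: `RagPromptSketch` and the `maxChars` parameter are hypothetical, and a character budget is a crude stand-in for a real token count.

```java
import java.util.List;

final class RagPromptSketch {

    // Join passages in rank order until adding the next one would exceed the budget.
    static String buildContext(List<String> passages, int maxChars) {
        StringBuilder context = new StringBuilder();
        for (String passage : passages) {
            int separator = context.isEmpty() ? 0 : 2; // "\n\n" between passages
            if (context.length() + separator + passage.length() > maxChars) {
                break;
            }
            if (separator > 0) {
                context.append("\n\n");
            }
            context.append(passage);
        }
        return context.toString();
    }

    // Fill the prompt template with the capped context and the question.
    static String buildPrompt(List<String> passages, String question, int maxChars) {
        return """
                Based on the following context, answer the question.
                Context:
                %s
                Question: %s
                Answer:
                """.formatted(buildContext(passages, maxChars), question);
    }
}
```

Because passages arrive best first, stopping at the first one that does not fit keeps the highest-ranked material.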
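Image Generation
DALL-E Integration
java
import java.util.List;

import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.stereotype.Service;

// Written against Spring AI 1.0 naming (ImageModel, un-prefixed builder methods);
// earlier releases used ImageClient and withModel(...)-style builders.
@Service
public class ImageGenerationService {

    private final ImageModel imageModel;

    public ImageGenerationService(ImageModel imageModel) {
        this.imageModel = imageModel;
    }

    public String generateImage(String prompt) {
        ImageResponse response = imageModel.call(
                new ImagePrompt(prompt, OpenAiImageOptions.builder()
                        .model("dall-e-3")
                        .width(1024)
                        .height(1024)
                        .build()));
        return response.getResult().getOutput().getUrl();
    }

    public List<String> generateMultipleImages(String prompt, int count) {
        // DALL-E 3 accepts only one image per request, so use a model that allows n > 1.
        ImageResponse response = imageModel.call(
                new ImagePrompt(prompt, OpenAiImageOptions.builder()
                        .model("dall-e-2")
                        .N(count)
                        .build()));
        return response.getResults().stream()
                .map(result -> result.getOutput().getUrl())
                .toList();
    }
}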
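Advanced Features
Function Calling
Defining a Function
java
import java.util.function.Function;
import org.springframework.stereotype.Component;

@Component // registered under the default bean name "weatherFunction"
public class WeatherFunction implements Function<WeatherRequest, WeatherResponse> {
    private final WeatherService weatherService; // application-specific weather API client (not shown)

    public WeatherFunction(WeatherService weatherService) {
        this.weatherService = weatherService;
    }

    @Override
    public WeatherResponse apply(WeatherRequest request) {
        // Call the weather API; records expose location(), not getLocation().
        return weatherService.getWeather(request.location());
    }
}

public record WeatherRequest(String location) {}
public record WeatherResponse(String location, double temperature, String description) {}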
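Using Function Calling
java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;

@Service
public class AssistantService {

    private final ChatClient chatClient;

    public AssistantService(ChatClient.Builder chatClientBuilder) {
        // Makes the function bean available to every request by bean name.
        // Spring AI 1.0 calls this defaultToolNames(...); earlier milestones
        // used defaultFunctions(...).
        this.chatClient = chatClientBuilder
                .defaultToolNames("weatherFunction")
                .build();
    }

    public String handleUserQuery(String query) {
        // The model decides whether to call the function while answering.
        return chatClient.prompt()
                .user(query)
                .call()
                .content();
    }
}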
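Behind the fluent API, function calling comes down to a lookup: the model replies with the name of a function and its arguments, the application runs that function, and the result goes back to the model. The sketch below illustrates only the lookup-and-invoke step in plain Java. `ToolDispatchSketch` and `ToolCall` are hypothetical names, and arguments are simplified to a single string rather than the JSON payload a real model returns.

```java
import java.util.Map;
import java.util.function.Function;

final class ToolDispatchSketch {

    // What the model asks for: a registered name plus an argument.
    record ToolCall(String name, String argument) {}

    // Registry keyed by tool name, analogous to resolving a function bean by name.
    private final Map<String, Function<String, String>> tools;

    ToolDispatchSketch(Map<String, Function<String, String>> tools) {
        this.tools = Map.copyOf(tools);
    }

    // Run the requested tool and return the text that would be sent back
    // to the model as the tool result.
    String dispatch(ToolCall call) {
        Function<String, String> tool = tools.get(call.name());
        if (tool == null) {
            return "Unknown tool: " + call.name();
        }
        return tool.apply(call.argument());
    }
}
```

Returning an error string for an unknown name, instead of throwing, lets the model see the failure and recover in its next turn; that is a design choice of this sketch.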
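Prompt Templates
Defining a Template
java
import java.util.Map;

import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Component;

@Component
public class PromptTemplates {

    // Placeholders in braces are filled in when the prompt is created.
    private final PromptTemplate summaryTemplate = new PromptTemplate("""
            Please summarize the following text in {language}.
            Keep the summary under {maxWords} words.
            Text to summarize:
            {text}
            Summary:
            """);

    public Prompt createSummaryPrompt(String text, String language, int maxWords) {
        Map<String, Object> variables = Map.of(
                "text", text,
                "language", language,
                "maxWords", maxWords);
        return summaryTemplate.create(variables);
    }
}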
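Conceptually, rendering a template means replacing each `{name}` placeholder with the matching variable. The following plain-Java sketch shows that substitution with a regular expression. It is not how Spring AI's `PromptTemplate` is implemented; `TemplateSketch` is a hypothetical helper, and it assumes placeholder names consist of word characters only.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

final class TemplateSketch {

    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    // Replace each {name} with its value; fail fast on a missing variable.
    static String render(String template, Map<String, Object> variables) {
        Matcher matcher = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (matcher.find()) {
            String name = matcher.group(1);
            if (!variables.containsKey(name)) {
                throw new IllegalArgumentException("Missing template variable: " + name);
            }
            // quoteReplacement stops '$' and '\' in values from being read as group references.
            String value = String.valueOf(variables.get(name));
            matcher.appendReplacement(out, Matcher.quoteReplacement(value));
        }
        matcher.appendTail(out);
        return out.toString();
    }
}
```

Failing on a missing variable surfaces template mistakes at render time instead of sending a prompt with a literal `{text}` in it.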
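Conversation Memory
Session Management
java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Service;

@Service
public class ConversationService {

    private final ChatClient chatClient;
    private final Map<String, List<Message>> conversations = new ConcurrentHashMap<>();

    public ConversationService(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    public String chat(String sessionId, String userMessage) {
        List<Message> history = conversations.computeIfAbsent(sessionId, k -> new ArrayList<>());
        // ArrayList is not thread-safe, so serialize requests within one session.
        synchronized (history) {
            // Record the user message.
            history.add(new UserMessage(userMessage));
            // Generate the reply from the full history.
            String response = chatClient.prompt()
                    .messages(history)
                    .call()
                    .content();
            // Record the assistant's reply.
            history.add(new AssistantMessage(response));
            // Cap the history at the 20 most recent messages.
            if (history.size() > 20) {
                history.subList(0, history.size() - 20).clear();
            }
            return response;
        }
    }

    public void clearConversation(String sessionId) {
        conversations.remove(sessionId);
    }
}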
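The history cap in the service above is a sliding window: once the limit is reached, each new message pushes out the oldest one. A minimal stand-alone sketch of that policy follows, using plain strings in place of message objects. `SlidingWindowMemory` is a hypothetical class written for illustration, not a Spring AI type.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

final class SlidingWindowMemory {

    private final int maxMessages;
    private final Deque<String> messages = new ArrayDeque<>();

    SlidingWindowMemory(int maxMessages) {
        if (maxMessages <= 0) {
            throw new IllegalArgumentException("maxMessages must be positive");
        }
        this.maxMessages = maxMessages;
    }

    // Append a message and evict the oldest ones beyond the window.
    synchronized void add(String message) {
        messages.addLast(message);
        while (messages.size() > maxMessages) {
            messages.removeFirst();
        }
    }

    // Immutable snapshot, oldest first, safe to hand to another thread.
    synchronized List<String> snapshot() {
        return List.copyOf(messages);
    }
}
```

A fixed message count is the simplest policy; a token-based budget or summarizing old turns are common alternatives when messages vary widely in length.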
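Real-World Examples
Intelligent Customer Service System
java
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

// ChatRequest and ChatResponse are application-defined DTOs (not shown).
@RestController
public class CustomerServiceController {

    private final RagService ragService;
    private final ConversationService conversationService;

    public CustomerServiceController(RagService ragService, ConversationService conversationService) {
        this.ragService = ragService;
        this.conversationService = conversationService;
    }

    @PostMapping("/customer-service/chat")
    public ChatResponse chat(@RequestBody ChatRequest request) {
        String sessionId = request.getSessionId();
        String message = request.getMessage();

        // Try the knowledge base first.
        String ragAnswer = ragService.answerQuestion(message);

        // Fall back to the conversational model when RAG has no suitable answer.
        String finalAnswer;
        if (isGoodAnswer(ragAnswer)) {
            finalAnswer = ragAnswer;
        } else {
            finalAnswer = conversationService.chat(sessionId, message);
        }
        return new ChatResponse(finalAnswer, sessionId);
    }

    private boolean isGoodAnswer(String answer) {
        // Placeholder quality check: rejects null and explicit "I don't know" replies.
        return answer != null && !answer.contains("I don't know");
    }
}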
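Document Analysis System
java
import java.io.IOException;
import java.util.List;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

// DocumentAnalysisResult and KeyInformation are application-defined types, and the
// helpers extractContent, chunkDocument, and storeInVectorDatabase are omitted here.
@Service
public class DocumentAnalysisService {

    private final ChatClient chatClient;
    private final VectorStore vectorStore;

    public DocumentAnalysisService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClient = chatClientBuilder.build();
        this.vectorStore = vectorStore;
    }

    public DocumentAnalysisResult analyzeDocument(MultipartFile file) throws IOException {
        // 1. Extract the document text.
        String content = extractContent(file);
        // 2. Split it into chunks.
        List<String> chunks = chunkDocument(content);
        // 3. Generate a summary.
        String summary = generateSummary(content);
        // 4. Extract key information.
        KeyInformation keyInfo = extractKeyInformation(content);
        // 5. Store the chunks in the vector database.
        storeInVectorDatabase(chunks, file.getOriginalFilename());
        return new DocumentAnalysisResult(summary, keyInfo);
    }

    private String generateSummary(String content) {
        return chatClient.prompt()
                .user("Please provide a concise summary of the following document:\n\n" + content)
                .call()
                .content();
    }

    private KeyInformation extractKeyInformation(String content) {
        return chatClient.prompt()
                .user("Extract key information from this document: " + content)
                .call()
                .entity(KeyInformation.class);
    }
}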
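The service above calls `chunkDocument` without showing it. One simple strategy, sketched here as an assumption rather than the author's implementation, is a fixed-size character window with overlap, so that a sentence cut at a chunk boundary still appears whole in a neighboring chunk. `ChunkerSketch` and its parameters are hypothetical; production systems often split on sentence or token boundaries instead.

```java
import java.util.ArrayList;
import java.util.List;

final class ChunkerSketch {

    // Split text into windows of `chunkSize` characters; consecutive windows
    // share `overlap` characters.
    static List<String> chunk(String text, int chunkSize, int overlap) {
        if (chunkSize <= 0 || overlap < 0 || overlap >= chunkSize) {
            throw new IllegalArgumentException("Require chunkSize > 0 and 0 <= overlap < chunkSize");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap;
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break; // the last window reached the end of the text
            }
        }
        return chunks;
    }
}
```

For example, `ChunkerSketch.chunk("abcdefghij", 4, 1)` yields the windows `abcd`, `defg`, and `ghij`, each sharing one character with its neighbor.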
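Configuration and Optimization
Performance Tuning
yaml
spring:
  ai:
    openai:
      chat:
        options:
          temperature: 0.7
          max-tokens: 1000
          # Check that your Spring AI version supports this key; HTTP timeouts
          # are often configured on the underlying HTTP client instead.
          timeout: 30s
    # Retry settings are shared across providers under spring.ai.retry.
    retry:
      max-attempts: 3
      backoff:
        initial-interval: 1s
        multiplier: 2
        max-interval: 10s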
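The retry settings describe exponential backoff: each wait is the previous one times the multiplier, capped at the maximum interval. The sketch below computes the resulting wait schedule so the effect of the numbers is easy to see. `BackoffSketch` is a hypothetical helper; it assumes no jitter and one wait between consecutive attempts, which may differ in detail from the retry library Spring AI uses.

```java
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

final class BackoffSketch {

    // Waits between attempts: initial, initial * multiplier, ... capped at max.
    // With n attempts there are n - 1 waits.
    static List<Duration> delays(int maxAttempts, Duration initial, double multiplier, Duration max) {
        List<Duration> result = new ArrayList<>();
        double nextMillis = initial.toMillis();
        for (int attempt = 1; attempt < maxAttempts; attempt++) {
            long capped = (long) Math.min(nextMillis, max.toMillis());
            result.add(Duration.ofMillis(capped));
            nextMillis *= multiplier;
        }
        return result;
    }
}
```

With the values from the configuration (3 attempts, 1s initial interval, multiplier 2, 10s cap), the schedule is a 1s wait followed by a 2s wait; the cap only comes into play from the fifth wait onward.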
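Security Configuration
java
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.web.SecurityFilterChain;

@Configuration
@EnableWebSecurity
public class AISecurityConfig {

    @Bean
    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
        return http
                .authorizeHttpRequests(auth -> auth
                        // AI endpoints require the AI_USER role; everything else only a login.
                        .requestMatchers("/ai/**").hasRole("AI_USER")
                        .anyRequest().authenticated()
                )
                // Authenticate callers with JWT bearer tokens.
                .oauth2ResourceServer(oauth2 -> oauth2.jwt(Customizer.withDefaults()))
                .build();
    }
}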
Monitoring and Logging
java
import java.time.Duration;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import org.springframework.stereotype.Component;

@Component
public class AIMetrics {

    private final MeterRegistry meterRegistry;

    public AIMetrics(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
    }

    public void recordRequest(String model, String operation) {
        // Tags are part of a meter's identity, so the counter is resolved per
        // tag set; register(...) returns the existing meter if there is one.
        Counter.builder("ai.requests.total")
                .description("Total AI requests")
                .tags("model", model, "operation", operation)
                .register(meterRegistry)
                .increment();
    }

    public void recordResponseTime(Duration duration, String model) {
        Timer.builder("ai.response.time")
                .description("AI response time")
                .tags("model", model)
                .register(meterRegistry)
                .record(duration);
    }
}
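Best Practices
- Prompt engineering: write clear, specific prompts
- Error handling: handle API rate limits, network errors, and other failures
- Cost control: monitor API call counts and token usage
- Data privacy: handle sensitive data securely
- Caching: cache the results of common queries
- Model selection: match the model to the complexity of the task
- Performance monitoring: track response times and success rates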
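As one concrete reading of the caching and cost-control items, the sketch below keeps a small least-recently-used cache in front of a model call and counts how often the model is actually invoked. `ResponseCacheSketch` is a hypothetical class, and the model is represented by a plain `Function<String, String>`. Exact-match caching like this only helps when prompts repeat verbatim, and it assumes a cached answer is still acceptable for later callers.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

final class ResponseCacheSketch {

    private final Map<String, String> cache;
    private int misses;

    ResponseCacheSketch(int capacity) {
        // accessOrder = true keeps entries ordered least-recently-used first.
        this.cache = new LinkedHashMap<>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String, String> eldest) {
                return size() > capacity;
            }
        };
    }

    // Return the cached answer, or call the model once and remember the result.
    synchronized String get(String prompt, Function<String, String> model) {
        String cached = cache.get(prompt);
        if (cached != null) {
            return cached;
        }
        misses++;
        String answer = model.apply(prompt);
        cache.put(prompt, answer);
        return answer;
    }

    // Number of times the model was actually called.
    synchronized int misses() {
        return misses;
    }
}
```

The miss counter doubles as a rough cost signal, since every miss corresponds to one billed model call.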
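Summary
Spring AI gives Java developers a capable, approachable framework for AI integration. Its unified abstraction layer lets them connect Spring applications to a range of AI services and build intelligent applications. As AI technology advances, Spring AI is expected to keep evolving and to offer more support for enterprise AI development.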