基于Azure实现Java访问OpenAI

1、前言

        之前使用 Java 代码访问 OpenAI(参见《OpenAI注册以及Java代码调用》,雨欲语的博客-CSDN博客),但需要 VPN 才能访问。现在可以通过微软的 Azure 访问 OpenAI,不再需要 VPN。官方文档:《快速入门 - 开始通过 Azure OpenAI 服务使用 ChatGPT 和 GPT-4》(Azure OpenAI Service | Microsoft Learn)。官方只对 Python 和 C# 进行了封装,没有提供 Java 版本,但可以直接通过 HTTP 请求调用 REST 接口(URI)进行访问。

        Azure OpenAI 服务的申请方式参见:《什么是 Azure OpenAI 服务?》(Azure Cognitive Services | Microsoft Learn)

2、返回结果封装

       首先根据接口的返回结果封装以下 Java 类:

        AzureAIChatResponse类:

/**
 * Top-level body of a chat-completions response, mapped one field per JSON property.
 */
public class AzureAIChatResponse {

    /** Response identifier ({@code id}). */
    private String id;
    /** Object type tag ({@code object}). */
    private String object;
    /**
     * Creation time ({@code created}). Kept as a String.
     * NOTE(review): presumably a Unix timestamp in the JSON; confirm against the API reference.
     */
    private String created;
    /** Model name reported by the service ({@code model}). */
    private String model;
    /** Token accounting ({@code usage}). */
    private AzureAIUsage usage;
    /** Generated alternatives ({@code choices}). */
    private List<AzureAIChoice> choices;

    public String getId() {
        return this.id;
    }

    public void setId(final String newId) {
        this.id = newId;
    }

    public String getObject() {
        return this.object;
    }

    public void setObject(final String newObject) {
        this.object = newObject;
    }

    public String getCreated() {
        return this.created;
    }

    public void setCreated(final String newCreated) {
        this.created = newCreated;
    }

    public String getModel() {
        return this.model;
    }

    public void setModel(final String newModel) {
        this.model = newModel;
    }

    public AzureAIUsage getUsage() {
        return this.usage;
    }

    public void setUsage(final AzureAIUsage newUsage) {
        this.usage = newUsage;
    }

    public List<AzureAIChoice> getChoices() {
        return this.choices;
    }

    public void setChoices(final List<AzureAIChoice> newChoices) {
        this.choices = newChoices;
    }
}

         AzureAIUsage类:

/**
 * Token usage block of a response. Example JSON:
 * <pre>
 * "prompt_tokens": 10,
 * "completion_tokens": 9,
 * "total_tokens": 19
 * </pre>
 * The snake_case JSON names are mapped to camelCase fields via {@code @SerializedName}.
 */
public class AzureAIUsage {

    @SerializedName("prompt_tokens")
    private int promptTokens;

    @SerializedName("completion_tokens")
    private int completionTokens;

    @SerializedName("total_tokens")
    private int totalTokens;

    /** @return value of {@code prompt_tokens} */
    public int getPromptTokens() {
        return this.promptTokens;
    }

    public void setPromptTokens(final int value) {
        this.promptTokens = value;
    }

    /** @return value of {@code completion_tokens} */
    public int getCompletionTokens() {
        return this.completionTokens;
    }

    public void setCompletionTokens(final int value) {
        this.completionTokens = value;
    }

    /** @return value of {@code total_tokens} */
    public int getTotalTokens() {
        return this.totalTokens;
    }

    public void setTotalTokens(final int value) {
        this.totalTokens = value;
    }
}

        AzureAIChoice类:

/**
 * One entry of the response's {@code choices} array.
 *
 * <p>The field was previously typed as {@code Message}, a class that does not exist in this
 * code; it is the {@link AzureAIMessage} returned by the service. Accessors are provided so
 * callers can read the generated reply (e.g. {@code choice.getMessage().getContent()}).
 */
public class AzureAIChoice {
    /** The generated message ({@code message} in the JSON). */
    private AzureAIMessage message;

    public AzureAIMessage getMessage() {
        return message;
    }

    public void setMessage(AzureAIMessage message) {
        this.message = message;
    }
}

        AzureAIMessage类:

/**
 * Message object found inside a response choice ({@code choices[i].message}).
 *
 * <p>Without accessors the {@code role}/{@code content} values could not be read after
 * deserialization, so getters and setters are provided.
 *
 * <p>NOTE(review): this article also defines a fuller {@code AzureAIMessage} whose role is typed
 * as the {@code AzureAIRole} enum. The two cannot coexist in one package — keep only one.
 */
public class AzureAIMessage {
    /** Speaker role as sent by the service (e.g. "assistant"). */
    private String role;
    /** Message text. */
    private String content;

    public String getRole() {
        return role;
    }

    public void setRole(String role) {
        this.role = role;
    }

    public String getContent() {
        return content;
    }

    public void setContent(String content) {
        this.content = content;
    }
}

3、参数封装

        根据请求参数封装以下类:

        AzureAIChatRequest类:

/**
 * Request body for the chat-completions endpoint, serialized with Gson.
 *
 * <p>Boxed fields left {@code null} are omitted from the JSON (Gson skips nulls by default), so
 * only explicitly set options are sent. Their getters therefore return the boxed type and may
 * return {@code null}. Previously they returned {@code int}, which threw a
 * NullPointerException for any option that had not been set.
 */
public class AzureAIChatRequest {
    /** Conversation messages ({@code messages}). */
    private List<AzureAIMessage> messages;
    /** Serialized as {@code temperature}. */
    private Double temperature;
    /** Number of choices to generate; serialized as {@code n}. */
    @SerializedName("n")
    private Integer choices;
    /** Primitive, so Gson always serializes it ({@code false} unless set). */
    private boolean stream;
    /** Stop sequence; serialized as {@code stop}. */
    private String stop;
    @SerializedName("max_tokens")
    private Integer maxTokens;
    // NOTE(review): the service documents the penalties as decimal numbers; Integer restricts
    // them to whole values. Kept as-is for backward compatibility.
    @SerializedName("presence_penalty")
    private Integer presencePenalty;
    @SerializedName("frequency_penalty")
    private Integer frequencyPenalty;
    /** Serialized as {@code user}. */
    private String user;

    public List<AzureAIMessage> getMessages() {
        return messages;
    }

    /**
     * Replaces the messages with a mutable copy of {@code messages} ({@code null} clears them).
     * Copying keeps {@link #addMessage} working even if the caller passed an immutable list such
     * as {@code List.of(...)}, and stops later changes to the caller's list from leaking in.
     */
    public void setMessages(List<AzureAIMessage> messages) {
        this.messages = (messages == null) ? null : new ArrayList<>(messages);
    }

    /** Appends one message, creating the list on first use. */
    public void addMessage(AzureAIMessage message) {
        if (this.messages == null) {
            this.messages = new ArrayList<>();
        }
        this.messages.add(message);
    }

    public Double getTemperature() {
        return temperature;
    }

    public void setTemperature(Double temperature) {
        this.temperature = temperature;
    }

    /** @return the requested number of choices ({@code n}), or {@code null} if not set */
    public Integer getChoices() {
        return choices;
    }

    public void setChoices(int choices) {
        this.choices = choices;
    }

    public boolean isStream() {
        return stream;
    }

    public void setStream(boolean stream) {
        this.stream = stream;
    }

    /** @return the stop sequence, or {@code null} if not set */
    public String getStop() {
        return stop;
    }

    /**
     * @return the stop sequence, or {@code null} if not set
     * @deprecated misnamed (it returns a String, not a boolean); use {@link #getStop()}.
     */
    @Deprecated
    public String isStop() {
        return getStop();
    }

    public void setStop(String stop) {
        this.stop = stop;
    }

    /**
     * Stores the literal text {@code "true"} or {@code "false"} as the stop sequence.
     *
     * @deprecated {@code stop} is a stop sequence, not a flag: this makes generation halt when the
     *     model emits the word "true"/"false". Use {@link #setStop(String)} instead.
     */
    @Deprecated
    public void setStop(boolean stop) {
        this.stop = stop ? "true" : "false";
    }

    /** @return {@code max_tokens}, or {@code null} if not set */
    public Integer getMaxTokens() {
        return maxTokens;
    }

    public void setMaxTokens(int maxTokens) {
        this.maxTokens = maxTokens;
    }

    /** @return {@code presence_penalty}, or {@code null} if not set */
    public Integer getPresencePenalty() {
        return presencePenalty;
    }

    public void setPresencePenalty(int presencePenalty) {
        this.presencePenalty = presencePenalty;
    }

    /** @return {@code frequency_penalty}, or {@code null} if not set */
    public Integer getFrequencyPenalty() {
        return frequencyPenalty;
    }

    public void setFrequencyPenalty(int frequencyPenalty) {
        this.frequencyPenalty = frequencyPenalty;
    }

    public String getUser() {
        return user;
    }

    public void setUser(String user) {
        this.user = user;
    }
}

         AzureAIMessage类(完整版:role 使用下文的 AzureAIRole 枚举,并提供构造方法和 getter/setter;与前面的简化版同名,实际项目中只保留一个即可):

/**
 * A single chat message: who is speaking ({@link AzureAIRole}) and what they said.
 */
public class AzureAIMessage {

    /** Speaker of this message. */
    private AzureAIRole role;
    /** Message text. */
    private String content;

    /** Creates an empty message; fill it in through the setters. */
    public AzureAIMessage() {
        // both fields remain null until set
    }

    /**
     * Creates a message with the given text and role.
     *
     * @param content message text
     * @param role    speaker of the message
     */
    public AzureAIMessage(String content, AzureAIRole role) {
        this.role = role;
        this.content = content;
    }

    public AzureAIRole getRole() {
        return this.role;
    }

    public void setRole(AzureAIRole newRole) {
        this.role = newRole;
    }

    public String getContent() {
        return this.content;
    }

    public void setContent(String newContent) {
        this.content = newContent;
    }
}

        AzureAIRole类:

/**
 * Chat message roles. {@code @SerializedName} gives Gson the lowercase wire value, and
 * {@link #toString()} returns that same value.
 */
public enum AzureAIRole {

    @SerializedName("assistant")
    ASSISTANT("assistant"),

    @SerializedName("system")
    SYSTEM("system"),

    @SerializedName("user")
    USER("user");

    /** Lowercase name as it appears in the JSON payload. */
    private final String wireName;

    AzureAIRole(final String wireName) {
        this.wireName = wireName;
    }

    /** @return the lowercase wire value, e.g. {@code "user"} */
    @Override
    public String toString() {
        return wireName;
    }
}

4、客户端访问

        客户端访问类:

import cn.hutool.core.date.BetweenFormatter;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.http.HttpRequest;
import cn.hutool.json.JSONUtil;
import com.google.gson.Gson;
import lombok.extern.slf4j.Slf4j;
import org.asynchttpclient.*;

import java.io.Closeable;
import java.io.IOException;
import java.util.Date;
import java.util.concurrent.Future;


@Slf4j
public class AzureAIClient implements Closeable {

    private static final String JSON = "application/json; charset=UTF-8";
    private final boolean closeClient;
    private final AsyncHttpClient client;
    private final String deploymentName;
    private final String url;
    private final String token;
    private static final Version version = new Version();
    private final String apiVersion;
    private boolean closed = false;
    Gson gson = new Gson();

    public AzureAIClient(String url, String apiKey, String deploymentName, String apiVersion) throws Exception {
        this.client = new DefaultAsyncHttpClient();

        this.url = url + "/openai/deployments/" + deploymentName + "/";
        this.token = apiKey;
        this.deploymentName = deploymentName;
        this.apiVersion = apiVersion;
        closeClient = true;
    }


    public boolean isClosed() {
        return closed || client.isClosed();
    }

    @Override
    public void close() {
        if (closeClient && !client.isClosed()) {
            try {
                client.close();
            } catch (IOException ex) {

            }
        }
        closed = true;
    }

    public static String getVersion() {
        return version.getBuildNumber();
    }

    public static String getBuildName() {
        return version.getBuildName();
    }

    public AzureAICompletionsResult getCompletion(AzureAICompletionRequest completion) throws Exception {
        //chat/completions
        Future<Response> f = client.executeRequest(buildRequest("POST", "completions?api-version=" + apiVersion, gson.toJson(completion)));
        Response r = f.get();
        if (r.getStatusCode() != 200) {

            throw new Exception("Could not get competion result");
        } else {
            return gson.fromJson(r.getResponseBody(), AzureAICompletionsResult.class);

        }
    }

    public AzureAICreateEmbedingResponse createEmbedding(AzureAIEmbedding embedding) throws Exception {
        Future<Response> f = client.executeRequest(buildRequest("POST", "embeddings?api-version=" + apiVersion, gson.toJson(embedding)));
        Response r = f.get();
        if (r.getStatusCode() != 200) {

            throw new Exception("Could not create embedding");
        } else {
            AzureAICreateEmbedingResponse azureAICreateEmbedingResponse =  JSONUtil.toBean(r.getResponseBody(), AzureAICreateEmbedingResponse.class);
            return azureAICreateEmbedingResponse;

        }
    }

    public AzureAIChatResponse sendMyChatRequest(AzureAIChatRequest chatRequest) throws Exception {
        Date startDateOne = DateUtil.date();
        String f = buildMyRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest));
        Date endDateOne = DateUtil.date();
        // 获取开始时间和结束时间的时间差
        long betweenDateOne = DateUtil.between(startDateOne, endDateOne, DateUnit.MS);
        // 格式化时间
        String formatBetweenOne = DateUtil.formatBetween(betweenDateOne, BetweenFormatter.Level.MILLISECOND);
        log.info(String.format("请求数据耗时(毫秒):%s",formatBetweenOne));
        Date startDate = DateUtil.date();
        System.err.println(f);
        AzureAIChatResponse azureAIChatResponse = gson.fromJson(f, AzureAIChatResponse.class);
        Date endDate = DateUtil.date();
        // 获取开始时间和结束时间的时间差
        long betweenDate = DateUtil.between(startDate, endDate, DateUnit.MS);
        // 格式化时间
        String formatBetween = DateUtil.formatBetween(betweenDate, BetweenFormatter.Level.MILLISECOND);
        log.info(String.format("格式化数据耗时(毫秒):%s",formatBetween));
        return azureAIChatResponse;
    }

    private String buildMyRequest(String type, String subUrl, String requestBody) {
//        RestTemplate restTemplate = new RestTemplate();
//        HttpHeaders httpHeaders = new HttpHeaders();
//        // 设置contentType
        httpHeaders.setContentType(MediaType.APPLICATION_JSON_UTF8);
//        // 给请求header中添加一些数据
//        httpHeaders.add("Accept", JSON);
//        httpHeaders.add("Content-Type", JSON);
//        httpHeaders.add("api-key", this.token);
//
//
//        HttpEntity<String> httpEntity = new HttpEntity<String>(requestBody, httpHeaders);
//        ResponseEntity<String> exchange = restTemplate.postForEntity(this.url + subUrl, httpEntity, String.class);
//
//        String resultRemote = exchange.getBody();//得到返回的值


        String accept = HttpRequest.post(this.url + subUrl)
                .header("Accept", JSON)
                .header("Content-Type", "application/json")
                .header("api-key", this.token)
                .setReadTimeout(30000)
                .body(requestBody)
                .execute()
                .body();
        return accept;
    }




    public AzureAIChatResponse sendChatRequest(AzureAIChatRequest chatRequest) throws Exception {
        Date startDateOne = DateUtil.date();
        Future<Response> f = client.executeRequest(buildRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest)));
//        Request r = buildRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest));
        Response r = f.get();
        Date endDateOne = DateUtil.date();
        // 获取开始时间和结束时间的时间差
        long betweenDateOne = DateUtil.between(startDateOne, endDateOne, DateUnit.MS);
        // 格式化时间
        String formatBetweenOne = DateUtil.formatBetween(betweenDateOne, BetweenFormatter.Level.MILLISECOND);
        log.info(String.format("请求数据耗时(毫秒):%s",formatBetweenOne));
        if (r.getStatusCode() != 200) {

            log.info("Could not create chat request - server resposne was " + r.getStatusCode() + " to url: " + url + "chat/completions?api-version=2023-03-15-preview");
            return null;
        } else {
            Date startDate = DateUtil.date();
//            System.err.println(r.getResponseBody());
            AzureAIChatResponse azureAIChatResponse =  JSONUtil.toBean(r.getResponseBody(), AzureAIChatResponse.class);
//            AzureAIChatResponse azureAIChatResponse = gson.fromJson(r.getResponseBody(), AzureAIChatResponse.class);
            Date endDate = DateUtil.date();
            // 获取开始时间和结束时间的时间差
            long betweenDate = DateUtil.between(startDate, endDate, DateUnit.MS);
            // 格式化时间
            String formatBetween = DateUtil.formatBetween(betweenDate, BetweenFormatter.Level.MILLISECOND);
            log.info(String.format("格式化数据耗时(毫秒):%s",formatBetween));
            return azureAIChatResponse;

        }
    }

    private Request buildRequest(String type, String subUrl) {
        RequestBuilder builder = new RequestBuilder(type);
        Request request = builder.setUrl(this.url + subUrl)
                .addHeader("Accept", JSON)
                .addHeader("Content-Type", JSON)
                .addHeader("Authorization", "Bearer " + this.token)
                .build();
        return request;
    }

    private Request buildRequest(String type, String subUrl, String requestBody) {
        RequestBuilder builder = new RequestBuilder(type);
        Request request = builder.setUrl(this.url + subUrl)
                .addHeader("Accept", JSON)
                .addHeader("Content-Type", JSON)
                .addHeader("api-key", this.token)
                .setBody(requestBody)
                .build();
        return request;
    }

}

5、调用测试

        调用测试:

/**
 * Demo: sends one system + user message pair and prints the response id.
 * Replace the placeholder strings passed to AzureAIClient with real values before running.
 *
 * @throws Exception propagated from the client constructor and the HTTP call
 */
public static void main(String[] args) throws Exception {
    // Request limits for this demo (previously referenced but never declared).
    int maxTokens = 800;
    double temperature = 0.7;

    // Assemble the conversation: a system prompt followed by the user's question.
    List<AzureAIMessage> azureAiMessageList = new ArrayList<>();
    AzureAIChatRequest azureAiChatRequest = new AzureAIChatRequest();

    AzureAIMessage azureAIMessage0 = new AzureAIMessage();
    azureAIMessage0.setRole(AzureAIRole.SYSTEM);
    azureAIMessage0.setContent("你是一个AI机器人,请根据提问进行回答");
    azureAiMessageList.add(azureAIMessage0);

    AzureAIMessage azureAIMessage1 = new AzureAIMessage();
    azureAIMessage1.setRole(AzureAIRole.USER);
    azureAIMessage1.setContent("请解释一下java的gc");
    azureAiMessageList.add(azureAIMessage1);

    azureAiChatRequest.setMessages(azureAiMessageList);
    azureAiChatRequest.setMaxTokens(maxTokens);
    azureAiChatRequest.setTemperature(temperature);
    // Streaming stays off: sendChatRequest parses the response as a single JSON document.
    azureAiChatRequest.setPresencePenalty(0);
    azureAiChatRequest.setFrequencyPenalty(0);
    azureAiChatRequest.setStop((String) null);

    // try-with-resources closes the async HTTP client; otherwise its threads can keep the JVM alive.
    try (AzureAIClient azureAIClient = new AzureAIClient("申请的azure地址", "azure的apikey",
            "模型(gpt-35-turbo)", "api版本:(2023-03-15-preview)")) {
        AzureAIChatResponse azureAIChatResponse = azureAIClient.sendChatRequest(azureAiChatRequest);
        System.out.println(azureAIChatResponse == null
                ? "request failed"
                : "response id: " + azureAIChatResponse.getId());
    }
}
        

6、依赖

Maven 依赖(客户端代码还用到了 Hutool 和 Lombok,需要一并引入):

<dependencies>

    <dependency>
        <groupId>org.asynchttpclient</groupId>
        <artifactId>async-http-client</artifactId>
        <version>2.12.3</version>
        <type>jar</type>
    </dependency>

    <dependency>
        <groupId>com.google.code.gson</groupId>
        <artifactId>gson</artifactId>
        <version>2.10.1</version>
    </dependency>

    <!-- AzureAIClient imports cn.hutool.* (HttpRequest, JSONUtil, DateUtil). -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <version>5.8.18</version>
    </dependency>

    <!-- AzureAIClient uses lombok's @Slf4j (compile-time only). The generated logger needs
         slf4j-api at runtime; add an SLF4J binding to actually see log output. -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.26</version>
        <scope>provided</scope>
    </dependency>
</dependencies>

猜你喜欢

转载自blog.csdn.net/qq_41061437/article/details/130927618