How to create a watsonx.ai REST client in Spring Boot?

This blog post shows an example approach to invoking a watsonx.ai endpoint from a Java Spring Boot application.

The example distributes the responsibilities across six classes: together they build the prompt for a question, send it to a model in watsonx.ai, and extract the relevant content from the model's answer. The following table lists these classes with a short description.

Class | Description
WatsonxModelEndpoint | Sends the request with the full prompt and the model parameters to watsonx.ai and post-processes the answer if needed.
WatsonxPromptHandler | Builds the prompt by replacing the variable parts (question and context) in a prompt template.
PromptTemplateFileInputHandler | Loads the prompt template text from a file.
WatsonxAnswerData | Represents the return value of the REST endpoint.
WatsonxAnswerHandler | Extracts the needed text from the raw answer of the model.
ExampleApplication | Represents the Spring Boot application.

The following diagram shows the simplified dependencies for the example.

The following image shows the sequence of calls needed to invoke watsonx.ai and get an answer.

Here is a description of the steps in the simplified sequence diagram:

  1. WatsonxModelEndpoint/askQuestion:
    Here, we provide the question and the context related to the question.
  2. WatsonxModelEndpoint/getAnswer:
    Here, we send the request to the watsonx.ai model and customize the returned answer.
  3. IBMCloudTokenEndpoint/getToken:
    Here, we get the IAM access token needed to invoke the watsonx.ai endpoint and get an answer from the model.
  4. WatsonxPromptHandler/getPrompt:
    Here, we load the prompt template as a string, replace the placeholders for the question and the context, and return the complete prompt.
  5. PromptTemplateFileInputHandler/getPromptTemplate:
    Here, we load the text from the prompt template file.
  6. watsonx.ai/generatetext:
    watsonx.ai generates the complete answer with the model and returns it.
  7. WatsonxAnswerHandler/extract needed text:
    We may want to extract particular parts of the answer from the response of the model hosted in watsonx.ai.
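The sequence above can be sketched with plain Java functional interfaces, which makes the flow easy to follow and to exercise without Spring or network access. This is a hypothetical illustration, not part of the example code: the names `PromptSource`, `ModelClient`, and `AnswerPipeline` are made up, and in the real example the model call (including fetching the IAM token) goes over HTTP to watsonx.ai.

```java
import java.util.function.UnaryOperator;

/** Hypothetical: builds the full prompt (steps 4 and 5, WatsonxPromptHandler). */
interface PromptSource {
    String getPrompt(String question, String context);
}

/** Hypothetical: sends the prompt and returns the raw generated text (steps 3 and 6). */
interface ModelClient {
    String generate(String prompt);
}

/** Hypothetical orchestration of the steps shown in the sequence diagram. */
class AnswerPipeline {
    private final PromptSource promptSource;
    private final ModelClient modelClient;
    private final UnaryOperator<String> answerExtractor;

    AnswerPipeline(PromptSource promptSource, ModelClient modelClient,
                   UnaryOperator<String> answerExtractor) {
        this.promptSource = promptSource;
        this.modelClient = modelClient;
        this.answerExtractor = answerExtractor;
    }

    /** Steps 1 and 2: build the prompt, call the model, post-process the answer. */
    String askQuestion(String question, String context) {
        String prompt = promptSource.getPrompt(question, context);
        String rawAnswer = modelClient.generate(prompt);
        return answerExtractor.apply(rawAnswer); // step 7
    }
}
```

Injecting each step this way lets a unit test replace the HTTP-backed model client with a simple stub.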

1. The source code

As a starting point, here is the curl command from watsonx.ai that invokes the “text/generation” endpoint:

curl "https://XXXX.ml.cloud.ibm.com/ml/v1/text/generation?version=2023-05-29" \
  -H 'Content-Type: application/json' \
  -H 'Accept: application/json' \
  -H 'Authorization: Bearer YOUR_ACCESS_TOKEN' \
  -d '{
 "input": "PROMPT CONTENT",
 "parameters": {
  "decoding_method": "greedy",
  "max_new_tokens": 500,
  "min_new_tokens": 0,
  "stop_sequences": [],
  "repetition_penalty": 1
 },
 "model_id": "MODEL_ID",
 "project_id": "PROJECT_ID"
}'

You can copy this command from the Prompt Lab in watsonx.ai, as shown in the image below.
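For comparison, the same curl call can be expressed with the JDK's built-in `java.net.http` client (Java 11+). This is a minimal sketch, not part of the original example: `GenerationRequestSketch` and `buildGenerationRequest` are hypothetical names, and the host `XXXX`, the model ID, the project ID, and the token are placeholders exactly as in the curl command. The sketch only builds the request.

```java
import java.net.URI;
import java.net.http.HttpRequest;

class GenerationRequestSketch {
    // Placeholder host, as in the curl command
    static final String URL =
            "https://XXXX.ml.cloud.ibm.com/ml/v1/text/generation?version=2023-05-29";

    // Same JSON payload as the curl -d argument
    static final String EXAMPLE_BODY = "{\"input\":\"PROMPT CONTENT\","
            + "\"parameters\":{\"decoding_method\":\"greedy\",\"max_new_tokens\":500,"
            + "\"min_new_tokens\":0,\"stop_sequences\":[],\"repetition_penalty\":1},"
            + "\"model_id\":\"MODEL_ID\",\"project_id\":\"PROJECT_ID\"}";

    /** Hypothetical helper: builds the POST request that the curl command sends. */
    static HttpRequest buildGenerationRequest(String accessToken, String jsonBody) {
        return HttpRequest.newBuilder(URI.create(URL))
                .header("Content-Type", "application/json")
                .header("Accept", "application/json")
                .header("Authorization", "Bearer " + accessToken)
                .POST(HttpRequest.BodyPublishers.ofString(jsonBody))
                .build();
    }
}
```

Sending it is then a single call such as `HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())`, which returns the JSON response as a string.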

2. WatsonxModelEndpoint

We use Spring's RestTemplate HTTP client to call watsonx.ai from the Spring Boot application. The code below implements these steps:

  1. Building the header
  2. Building the payload JSON
  3. Building the request for the watsonx.ai endpoint
  4. Sending the request to the watsonx.ai endpoint
package example.external_endpoints;

// Related classes (IBMCloudTokenEndpoint is in this package and is not shown in this post)
import example.data_formats.WatsonxAnswerData;
import example.prompts.WatsonxPromptHandler;
import example.answer_handlers.WatsonxAnswerHandler;

import java.util.List;

// REST endpoint handling
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;

// JSON handling
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

// Logging
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WatsonxModelEndpoint {
    private static final Logger log = LoggerFactory.getLogger(WatsonxModelEndpoint.class);

    // Placeholders: replace XXX, XXXX and MODEL with your region host, project ID and model ID
    private static final String version = "2023-05-29";
    private static final String url = "https://XXX.ml.cloud.ibm.com/ml/v1/text/generation?version=" + version;
    private static final String project_id = "XXXX";
    private static final String model_id = "MODEL";
    private static final String decoding_method = "greedy";
    private static final int min_new_tokens = 1;
    private static final int max_new_tokens = 300;
    private static final double repetition_penalty = 1.0;

    // Plain RestTemplate instance: Spring Boot does not auto-configure a RestTemplate bean,
    // and this class is not a Spring bean, so @Autowired would not be applied here
    private final RestTemplate restTemplate = new RestTemplate();

    public WatsonxAnswerData askQuestion(String question, String context) {
        WatsonxAnswerData watsonxAnswer = new WatsonxAnswerData();
        watsonxAnswer.setAnswer(getAnswer(question, context));
        return watsonxAnswer;
    }

    private String getAnswer(String question, String context) {
        log.info("Log: getAnswer");

        // Without a valid IAM access token the request cannot succeed
        String accessToken = getIamTokenString();
        if (accessToken == null) {
            return null;
        }

        // 1. Building the header
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.setAccept(List.of(MediaType.APPLICATION_JSON));
        headers.setBearerAuth(accessToken);

        // 2. Building the payload JSON
        /* Request format
         {"input": prompt,
          "parameters":{"decoding_method":"greedy",
                        "max_new_tokens": 500,
                        "min_new_tokens": 0,
                        "stop_sequences":[],
                        "repetition_penalty": 1.0 },
          "model_id":"MODEL_ID",
          "project_id":"PROJECT_ID"
         }
         */
        JSONObject parameters = new JSONObject();
        parameters.put("decoding_method", decoding_method);
        parameters.put("max_new_tokens", max_new_tokens);
        parameters.put("min_new_tokens", min_new_tokens);
        parameters.put("stop_sequences", new JSONArray()); // a JSON array, not the string "[]"
        parameters.put("repetition_penalty", repetition_penalty);

        JSONObject watsonxRequestJSON = new JSONObject();
        watsonxRequestJSON.put("input", getPrompt(question, context));
        watsonxRequestJSON.put("parameters", parameters);
        watsonxRequestJSON.put("model_id", model_id);
        watsonxRequestJSON.put("project_id", project_id);

        // 3. Building the request for the watsonx.ai endpoint
        HttpEntity<String> request = new HttpEntity<>(watsonxRequestJSON.toString(), headers);

        try {
            // Log only the body: the headers contain the bearer token
            log.info("Log: invoke watsonx:\n" + request.getBody());
            // 4. Send the request to the watsonx.ai endpoint
            // (by default, RestTemplate throws an exception for 4xx and 5xx responses)
            ResponseEntity<String> response = restTemplate.postForEntity(url, request, String.class);

            if (response.getStatusCode().value() != 200) {
                log.error("Log: unexpected status " + response.getStatusCode());
                return null;
            }

            JSONObject watsonxResponseJSON = new JSONObject(response.getBody());
            log.info("Log: " + watsonxResponseJSON);
            /* Example response format
            { "model_id":"MODEL_ID",
              "created_at":"DATE",
              "results":[{"generated_text":"",
                          "generated_token_count":184,
                          "input_token_count":63,
                          "stop_reason":"eos_token"}],
              "system":{"warnings":[{"message":"Info",
                                     "id":"disclaimer_warning",
                                     "more_info":"MORE_INFO_URL"}]
                       }
            }
            */
            JSONArray results = watsonxResponseJSON.getJSONArray("results");
            JSONObject firstResult = results.getJSONObject(0);
            String answerRawString = firstResult.getString("generated_text");

            // Post-process the raw model output
            WatsonxAnswerHandler watsonxAnswerHandler = new WatsonxAnswerHandler();
            return watsonxAnswerHandler.extractString(answerRawString);
        } catch (JSONException e) {
            log.error("Log: JSON error: " + e);
            return null;
        } catch (Exception e) {
            log.error("Log: " + e.getMessage());
            return null;
        }
    }

    private String getPrompt(String question, String context) {
        WatsonxPromptHandler watsonxPromptHandler = new WatsonxPromptHandler();
        return watsonxPromptHandler.getPrompt(question, context);
    }

    private String getIamTokenString() {
        // IBMCloudTokenEndpoint returns the raw JSON response of the IAM token request
        String response = IBMCloudTokenEndpoint.getIamToken();

        try {
            JSONObject respObject = new JSONObject(response);
            return respObject.getString("access_token");
        } catch (Exception e) {
            // Do not return the error text: it would be sent as the bearer token
            log.error("Log: getIamTokenString error: " + e);
            return null;
        }
    }
}
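`IBMCloudTokenEndpoint` is referenced above but not shown in this post. As a hedged sketch, assuming the standard IBM Cloud IAM token API (a form-encoded POST of an API key to `https://iam.cloud.ibm.com/identity/token`, answered with JSON that contains `access_token`), the token request could be built with the JDK HTTP client as follows. The class and method names are hypothetical, and the API key is a placeholder.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;

class IamTokenRequestSketch {
    // Assumption: the public IBM Cloud IAM token endpoint and the API key grant type
    static final String IAM_URL = "https://iam.cloud.ibm.com/identity/token";
    static final String GRANT_TYPE = "urn:ibm:params:oauth:grant-type:apikey";

    /** Builds the URL-encoded form body that exchanges an API key for a token. */
    static String formBody(String apiKey) {
        return "grant_type=" + URLEncoder.encode(GRANT_TYPE, StandardCharsets.UTF_8)
                + "&apikey=" + URLEncoder.encode(apiKey, StandardCharsets.UTF_8);
    }

    /** Hypothetical helper: builds the POST request for the IAM token. */
    static HttpRequest buildTokenRequest(String apiKey) {
        return HttpRequest.newBuilder(URI.create(IAM_URL))
                .header("Content-Type", "application/x-www-form-urlencoded")
                .header("Accept", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(formBody(apiKey)))
                .build();
    }
}
```

The JSON response carries the `access_token` field that `getIamTokenString()` reads. The token is short-lived, so a production client would typically cache it and refresh it before it expires instead of requesting a new token for every question.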

3. WatsonxPromptHandler

package example.prompts;

// Logger
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// Related classes
import example.file_inputhandlers.PromptTemplateFileInputHandler;

public class WatsonxPromptHandler {
    private static final Logger log = LoggerFactory.getLogger(WatsonxPromptHandler.class);

    private static final String questionKeyword = "`{user_question}`";
    private static final String contextKeyword = "`{context_string}`";
    private static final String promptTemplateName = "Prompt_template.md";


    // Returns the prompt template with the question and the context filled in
    public String getPrompt(String question, String context) {
        PromptTemplateFileInputHandler promptTemplate = new PromptTemplateFileInputHandler();

        String promptTemplateContent = promptTemplate.getPromptTemplate(promptTemplateName);
        promptTemplateContent = promptTemplateContent.replace(questionKeyword, question);
        promptTemplateContent = promptTemplateContent.replace(contextKeyword, context);

        log.info("Log: " + promptTemplateContent);
        return promptTemplateContent;
    }

    // Returns the unmodified prompt template
    public String getRawPromptString() {
        PromptTemplateFileInputHandler promptTemplate = new PromptTemplateFileInputHandler();

        String promptTemplateContent = promptTemplate.getPromptTemplate(promptTemplateName);
        log.info("Log: " + promptTemplateContent);

        return promptTemplateContent;
    }

}
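The placeholder replacement in `WatsonxPromptHandler` is plain `String.replace`. The stdlib-only sketch below shows its effect on a small inline template; the template text is made up for illustration, while the placeholders match the `questionKeyword` and `contextKeyword` constants above. `PromptTemplateSketch` is a hypothetical name.

```java
class PromptTemplateSketch {
    static final String QUESTION_KEYWORD = "`{user_question}`";
    static final String CONTEXT_KEYWORD = "`{context_string}`";

    /** Replaces every occurrence of both placeholders, treating all arguments literally. */
    static String fillTemplate(String template, String question, String context) {
        return template.replace(QUESTION_KEYWORD, question)
                       .replace(CONTEXT_KEYWORD, context);
    }
}
```

`String.replace(CharSequence, CharSequence)` does not interpret regular expressions, unlike `replaceAll`, so a `$` or `\` in the user's question or in the retrieved context does not break the substitution.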

4. WatsonxAnswerData

package example.data_formats;

public class WatsonxAnswerData {
 
    private String answer;

    public String getAnswer() {
        return answer;
    }

    public void setAnswer(String answer) {
        this.answer = answer;
    }

}
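On Java 16 or later, the same data class could alternatively be written as a record. This is a suggestion, not what the example uses: with a record, `askQuestion` would construct the value with `new WatsonxAnswerData(answer)` instead of calling `setAnswer`. Jackson, which Spring Boot uses to serialize return values to JSON, supports records from version 2.12, so the REST response should still look like `{"answer": "..."}`; verify this against your Spring Boot version.

```java
/** Immutable alternative to the WatsonxAnswerData class (Java 16+). */
record WatsonxAnswerData(String answer) {
}
```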

5. PromptTemplateFileInputHandler

package example.file_inputhandlers;

// Logging
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// File system / IO
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.util.FileCopyUtils;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class PromptTemplateFileInputHandler {
    private static final Logger log = LoggerFactory.getLogger(PromptTemplateFileInputHandler.class);

    // Loads the prompt template text; returns "" if the file cannot be read
    public String getPromptTemplate(String promptTemplateName) {

        String location = "PATH_TO_FILE/" + promptTemplateName; // placeholder directory
        Resource resource = new FileSystemResource(location);

        try {
            // copyToByteArray closes the input stream when done
            byte[] bdata = FileCopyUtils.copyToByteArray(resource.getInputStream());
            String content = new String(bdata, StandardCharsets.UTF_8);
            log.info("Log:\n" + content);
            return content;
        } catch (IOException e) {
            log.error("Error: PromptTemplateFileInputHandler\n(" + e + ")");
            return "";
        }
    }
}
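If you do not need Spring's `Resource` abstraction, the JDK can read the template in one call with `Files.readString` (Java 11+). This is a minimal alternative sketch: the class name is hypothetical, and the `directory` argument stands in for the `PATH_TO_FILE` placeholder. Unlike the handler above, it lets the `IOException` propagate instead of returning an empty string, so a missing template fails loudly rather than producing an empty prompt.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

class PromptTemplateFileSketch {
    /** Reads the whole template file as UTF-8 text; throws if it is missing or unreadable. */
    static String getPromptTemplate(Path directory, String promptTemplateName) throws IOException {
        return Files.readString(directory.resolve(promptTemplateName), StandardCharsets.UTF_8);
    }
}
```

Inside a Spring Boot application, another option is to package the template under `src/main/resources` and load it with Spring's `ClassPathResource`, which avoids a hard-coded file system path.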

6. WatsonxAnswerHandler

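package example.answer_handlers;

public class WatsonxAnswerHandler {

    // Returns the part of the model answer before the first "\n```" (a closing Markdown code fence)
    public String extractString(String answer) {
        String sub_end = "\n```";

        String[] split_result = answer.split(sub_end);
        // split() returns an empty array if the answer consists only of the delimiter
        return split_result.length > 0 ? split_result[0] : "";
    }
}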
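`extractString` keeps only the text before the first `\n```` sequence, that is, before a closing Markdown code fence in the model output. Note that `String.split` takes a regular expression and drops trailing empty strings, so an answer that consists only of the delimiter yields an empty array. A slightly more defensive variant, sketched below as an assumption rather than the author's code, cuts with `indexOf` instead of a regex and also handles a `null` answer. `AnswerExtractionSketch` is a hypothetical name.

```java
class AnswerExtractionSketch {
    private static final String FENCE_END = "\n```";

    /** Returns the text before the first "\n```", the whole answer if absent, or "" for null. */
    static String extractString(String answer) {
        if (answer == null) {
            return "";
        }
        int end = answer.indexOf(FENCE_END);
        return end >= 0 ? answer.substring(0, end) : answer;
    }
}
```

For example, `extractString("SELECT 1;\n```\nMore text")` returns `"SELECT 1;"`, and an answer without a fence is returned unchanged.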

7. Related blog posts

These are related blog posts on this topic:


I hope this was useful to you, and let’s see what’s next!

Greetings,

Thomas

#watsonxai, #springboot, #java, #rest

3 thoughts on “How to create a watsonx.ai REST client in Spring Boot?”


  1. Thank you, the API key did bite me a little bit; the settings are not that intuitive. And I am able to invoke the watsonx REST API successfully! For JSON stuff, Spring Boot has built-in support: just autowire ObjectMapper and use objectMapper.readValue(requestBody, Map.class), where the requestBody can be a Java 17 text block! So I can copy the entire JSON and convert it to a Map in one line, and RestTemplate also accepts a Map as the request body. ObjectMapper can also be used to parse the JSON response.

