combine aws sagemaker endpoint with extraction function? #119

@MaxS3552284

Description

Hello, I tried to combine the LangChain extraction functions (from https://python.langchain.com/docs/use_cases/extraction/quickstart) with a Mistral-7B-Instruct endpoint running on AWS SageMaker.

I replaced

llm = ChatMistralAI(model="mistral-large-latest", temperature=0)

(as described in the quickstart linked above) with

from langchain_community.llms import SagemakerEndpoint
from langchain_core.callbacks import StreamingStdOutCallbackHandler

llm = SagemakerEndpoint(
    endpoint_name=MISTRAL_ENDPOINT,
    region_name=AWS_REGION,
    content_handler=ContentHandler_mistral7B(),
    callbacks=[StreamingStdOutCallbackHandler()],
    endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
)

import json
from typing import Dict, List

from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

class ContentHandler_mistral7B(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # Wrap the raw prompt as a single-turn chat message.
        input_user = [
            {
                "role": "user",
                "content": prompt,
            }
        ]

        prompt_formatted = format_instructions(input_user)
        payload = {
            "inputs": prompt_formatted,
            "parameters": {"max_new_tokens": 256, "do_sample": True, "temperature": 0.1},
        }

        input_str = json.dumps(
            payload,
            ensure_ascii=False,
        )
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        content = response_json[0]["generated_text"]

        return content
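For what it's worth, the output handler can be exercised without SageMaker at all: `invoke_endpoint` returns a body object exposing `.read()`, which `io.BytesIO` mimics. The TGI-style response shape below matches the handler above; the sample payload text is invented.

```python
import io
import json

# Invented TGI-style response body, shaped like what a Mistral-7B SageMaker
# endpoint returns: a JSON list with one "generated_text" entry.
fake_body = io.BytesIO(
    json.dumps([{"generated_text": "Alice is 30 years old."}]).encode("utf-8")
)

# Same logic as transform_output above, inlined for a standalone check.
response_json = json.loads(fake_body.read().decode("utf-8"))
content = response_json[0]["generated_text"]
print(content)  # Alice is 30 years old.
```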

def format_instructions(instructions: List[Dict[str, str]]) -> str:
    # from the SageMaker example notebook
    """Format instructions where conversation roles must alternate user/assistant/user/assistant/..."""
    prompt: List[str] = []  # collect the prompt pieces as a list of strings
    for user, answer in zip(instructions[::2], instructions[1::2]):
        prompt.extend(["<s>", "[INST] ", user["content"].strip(), " [/INST] ", answer["content"].strip(), "</s>"])
    prompt.extend(["<s>", "[INST] ", instructions[-1]["content"].strip(), " [/INST] "])
    return "".join(prompt)  # join the pieces into a single string
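For reference, the formatter itself behaves as expected when run standalone (function copied from above; the example message is invented):

```python
from typing import Dict, List

def format_instructions(instructions: List[Dict[str, str]]) -> str:
    """Format alternating user/assistant messages into a Mistral [INST] prompt."""
    prompt: List[str] = []
    for user, answer in zip(instructions[::2], instructions[1::2]):
        prompt.extend(["<s>", "[INST] ", user["content"].strip(), " [/INST] ",
                       answer["content"].strip(), "</s>"])
    prompt.extend(["<s>", "[INST] ", instructions[-1]["content"].strip(), " [/INST] "])
    return "".join(prompt)

# Single-turn message, as produced by transform_input:
single = format_instructions([{"role": "user", "content": "Extract the person's name."}])
print(single)  # <s>[INST] Extract the person's name. [/INST]  (note the trailing space)
```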

which resulted in the error

LangChainBetaWarning: The function with_structured_output is in beta. It is actively being worked on, so the API may change.
  warn_beta(

NotImplementedError

when executing

runnable = prompt | llm.with_structured_output(schema=Person)

I suspect I messed up the input-text transformation in the content handler, but I can't figure out how.

Can you help me?
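(For context: as far as I can tell, the NotImplementedError comes from with_structured_output not being implemented on plain non-chat LLM wrappers like SagemakerEndpoint, rather than from the content handler. The manual alternative would be to ask for JSON in the prompt and parse the completion myself, roughly like this stdlib-only sketch, where the Person fields and the sample completion are invented:)

```python
import json

# Prompt suffix asking the model for JSON matching a hypothetical Person
# schema with the same fields as the quickstart example.
schema_instructions = (
    "Return ONLY a JSON object with keys 'name' (string) and 'age' (integer)."
)

# Invented raw completion, standing in for transform_output's return value.
raw_completion = '{"name": "Alice", "age": 30}'

person = json.loads(raw_completion)
assert set(person) == {"name", "age"}
print(person["name"], person["age"])  # Alice 30
```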
