Guarded LLM
This example shows how the client's interactions with the LLM backend API can be wrapped to introduce input and output guardrails. Guardrails are important because they help steer the LLM to provide only contextually appropriate answers and not stray from your organisation's designated purpose and guidelines.
In this example we show how guardrails can be implemented to recognise defined inputs, such as greetings, and respond in a predefined manner. We also show how the LLM can be used to detect inappropriate questions or responses and not return them to the user.
This example wraps calls made through an `ImagineClient` to add basic guardrail functionality through `input_guard` and `output_guard` checks. It allows users to specify blocked content and topics either through an embeddings-based RAG matching pipeline or through direct chat with the LLM (the latter is more resource intensive).
import os
from typing import List,Dict,Any
import shutil
from dotenv import load_dotenv
from imagine.client import ImagineClient
from imagine.langchain import ImagineEmbeddings
from imagine.exceptions import (
ImagineException,
)
from imagine.types.chat_completions import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatMessage,
)
from imagine.types.completions import (
CompletionRequest,
CompletionResponse,
)
from langchain_core.documents import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma
Create a RAG VectorDB
We create a vector database so that we can store predefined content that is either allowed or disallowed and later perform RAG. The functions involved are `create_documents`, `create_vector_store` and `query_vector_store`.
The RAG feature allows us to use an embedding model and reduce usage of the LLM, making the application more cost- and time-efficient.
# Functions to enable RAG - see the simple RAG example for further details.


def create_documents(transcript, type=None):
    """Split a block of text into LangChain ``Document`` chunks.

    Args:
        transcript: The raw text to split (a single string, not a directory).
        type: Optional tag stored under the ``"type"`` metadata key of every
            chunk, e.g. ``"disallowed"`` or ``"greeting"``. ``None`` is stored
            as an empty string. The name shadows the builtin ``type`` but is
            kept so existing keyword callers keep working.

    Returns:
        A list of ``Document`` objects, one per chunk produced by a
        ``CharacterTextSplitter`` with ``chunk_size=1000`` and
        ``chunk_overlap=250``.
    """
    tag = "" if type is None else type
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=250)
    return [
        Document(page_content=chunk, metadata={"type": tag})
        for chunk in splitter.split_text(transcript)
    ]
# Build (or rebuild) a persisted Chroma vector store from a list of documents.
def create_vector_store(docs, db_dir, store_name, embedding_fn):
    """(Re)create the Chroma store at ``os.path.join(db_dir, store_name)``.

    An existing directory at that location is removed first, so the new store
    contains exactly ``docs``, embedded with ``embedding_fn``. Progress
    messages are printed; nothing is returned.

    Args:
        docs: Documents to index.
        db_dir: Parent directory for the store.
        store_name: Sub-directory name of the store inside ``db_dir``.
        embedding_fn: Embedding function handed to Chroma.
    """
    target_dir = os.path.join(db_dir, store_name)
    # The real path appears to have been redacted from this message on purpose.
    print("Persistent directory: <path/to/your/persistent_directory>")

    if os.path.exists(target_dir):
        print(f"\n--- Removing old vector store {store_name} ---")
        # ignore_errors=True: a failed delete is not raised here; it is only
        # noticed by the existence check below.
        shutil.rmtree(target_dir, ignore_errors=True)

    if os.path.exists(target_dir):
        # Only reached when the directory could not be removed above.
        print(f"Vector store {store_name} already exists. No need to initialize.")
        return

    print(f"\n--- Creating vector store {store_name} ---")
    Chroma.from_documents(docs, embedding_fn, persist_directory=target_dir)
    print(f"--- Finished creating vector store {store_name} ---")
# Query a persisted vector store for documents similar to ``query``.
def query_vector_store(query, db_dir, store_name, embedding_fn, k=2, threshold=0.3):
    """Return documents from the store whose relevance meets ``threshold``.

    Args:
        query: Text to search for.
        db_dir: Parent directory of the store.
        store_name: Sub-directory name of the store inside ``db_dir``.
        embedding_fn: Embedding function used to open the Chroma store.
        k: Maximum number of documents to return.
        threshold: Minimum relevance score passed to the retriever as
            ``score_threshold``.

    Returns:
        A list of matching documents, possibly empty. If the store directory
        does not exist, a message is printed and an empty list is returned.
    """
    persistent_directory = os.path.join(db_dir, store_name)
    if not os.path.exists(persistent_directory):
        print(f"Vector store {store_name} does not exist.")
        # Return an empty list (not None) so callers can safely use len().
        return []

    db = Chroma(persist_directory=persistent_directory, embedding_function=embedding_fn)
    retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": k, "score_threshold": threshold},
    )
    return retriever.invoke(query)
Define Input and Output checks
`check_input_completion` implements the embedding similarity check for inappropriate inputs.
It retrieves documents from the doc store that are sufficiently similar to the user input.
If any such documents are found, we check whether any of the matches are for disallowed content. If disallowed content is identified, the request is denied.
We also check here for defined chat content. We define a fixed ‘greeting’: if all the documents found are tagged `greeting`, we respond with our predefined greeting response.
The sensitivity and the number of documents required to trigger an event can be configured through the `threshold` and `k` arguments of `query_vector_store`.
After the embedding check, there is a topical-rail check using direct LLM questioning, which calls the LLM once for each of the user-specified topics. We check for ‘yes’ in the response, because our prompt asks the LLM a yes/no question: does the content discuss the given topic? If the content discusses a topic that is not allowed, the LLM should respond Yes, and we then disallow the input or output.
This operation is more expensive as it requires the LLM to generate a text response.
def check_input_completion(client, prompt: str, topical_guards: List[str] = None):
    """Apply the input guardrails to a completion prompt.

    Two checks run in order:

    1. Embedding check: the prompt is matched against the module-level vector
       store (globals ``db_dir``, ``store_name``, ``embedding_fn``). Up to 3
       documents meeting the 0.5 relevance-score threshold are retrieved.
       - If every retrieved document is tagged ``"greeting"``, a canned
         ``CompletionResponse`` is returned.
       - Any other match denies the request.
    2. Topical check: each topic in ``topical_guards`` is sent to
       ``_topical_guardrail``. A reply containing "yes" (any case) denies the
       request.

    Args:
        client: The ``ImagineClient`` used for the topical LLM checks.
        prompt: The user's prompt.
        topical_guards: Topics the prompt must not discuss. ``None`` (the
            default) or an empty list means no topical checks.

    Returns:
        A canned ``CompletionResponse`` for greetings. Otherwise ``None``,
        meaning the prompt may be sent to the LLM.

    Raises:
        ImagineException: If the prompt matches disallowed content or a
            prohibited topic.
    """
    # ``or []`` tolerates a ``None`` result when the vector store is missing.
    relevant_docs = query_vector_store(
        prompt, db_dir, store_name, embedding_fn, k=3, threshold=0.5
    ) or []

    if relevant_docs:
        if all(d.metadata.get("type") == "greeting" for d in relevant_docs):
            return CompletionResponse(
                choices=[{"finish_reason": "stop", "index": 0,
                          "text": " Hello I am Imagine!"}],
                created=0.0, id="0", model="Llama-3-8B", object="completion",
                generation_time=0.0,
                usage={"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
                ts="2024-09-02T15:45:34.556072946Z",
            )
        raise ImagineException(f"Input like {prompt} are Not Allowed")

    # This is a plain function (no ``self``), so there are no instance-level
    # defaults to fall back on: ``None`` simply means "no topical guards".
    for topic in topical_guards or []:
        topic_rail_response = _topical_guardrail(client, prompt, topic)
        answer = topic_rail_response.first_content or ""
        if "yes" in answer.lower():
            raise ImagineException(f"Questions about {topic} are Not Allowed")
    return None
`check_output` implements the topical-rail check using direct LLM questioning, calling the LLM once for each of the user-specified topics.
An exception is triggered if the content is disallowed but other behaviour is possible.
def check_output(client, prompt: str, topical_guards: List[str] = None):
    """Apply the topical output guardrails to text produced by the LLM.

    Each topic in ``topical_guards`` is sent to ``_topical_guardrail``. A
    reply containing "yes" (any case, matching the input checks) means the
    text discusses a prohibited topic.

    Args:
        client: The ``ImagineClient`` used for the topical LLM checks.
        prompt: The text to check. Despite the name, this is normally the
            LLM's response.
        topical_guards: Prohibited topics. ``None`` (the default) or an empty
            list means nothing is checked.

    Returns:
        ``None`` when the text passes every check.

    Raises:
        ImagineException: If the text discusses a prohibited topic.
    """
    for topic in topical_guards or []:
        topic_rail_response = _topical_guardrail(client, prompt, topic)
        answer = topic_rail_response.first_content or ""
        if "yes" in answer.lower():
            raise ImagineException(f"Questions about {topic} are Not Allowed")
    return None
The `_topical_guardrail` function uses the `ChatCompletionRequest` object to query the Imagine `chat/completions` endpoint.
We construct chat messages for the system, user, and assistant to build a conversation whose goal is to have the LLM
decide whether the content of the user's prompt is allowed under the stated restriction.
When this is used as an output guard, the "user's prompt" may instead be the response returned by the LLM.
Prompt tuning is recommended here for production-ready systems.
def _topical_guardrail(client, prompt: str, topic: str):
    """Ask the LLM whether ``prompt`` discusses ``topic``.

    Builds a short system/user/assistant conversation that asks a yes/no
    question. It POSTs that conversation to the ``chat/completions`` endpoint
    with temperature 0 using the client's default LLM.

    Args:
        client: The ``ImagineClient`` to send the request with.
        prompt: The text to classify (a user prompt or an LLM response).
        topic: The restricted topic to test for.

    Returns:
        The parsed ``ChatCompletionResponse``. Callers inspect
        ``first_content`` for the yes/no answer.

    Raises:
        ImagineException: If the endpoint returns an empty or non-dict body.
    """
    # NOTE(review): ``prompt`` is untrusted and interpolated straight into the
    # conversation, so crafted input could try to steer the checker's answer
    # (prompt injection). Harden before production use.
    messages = [
        {
            "role": "system",
            "content": f"You are an expert content checker. You determine whether the user is asking about specific content. You respond Yes if they are asking about that content and No otherwise! Do not answer any question other than Does the user discuss {topic}",
        },
        {"role": "user", "content": f"The user says : {prompt}:"},
        {
            "role": "assistant",
            "content": f"Does the user discuss {topic}?[yes/no] : answer[yes/no]:",
        },
    ]
    request_body = ChatCompletionRequest(
        messages=messages,
        model=client.default_model_llm,
        stream=False,
        temperature=0,  # keep the yes/no verdict as stable as possible
    ).model_dump(exclude_none=True)
    # Uses the client's private ``_request`` helper to call the endpoint directly.
    response = client._request("post", "chat/completions", request_body)
    if not response:
        raise ImagineException("No response received")
    if not isinstance(response, dict):
        raise ImagineException("Unexpected response body")
    return ChatCompletionResponse(**response)
Set up the required variables
We use the dotenv library to store environment variables, which we load using the `load_dotenv` function.
In particular, we set `IMAGINE_API_KEY` and the vector database path in `IMAGINE_GUARDRAILS_CHROMA_DB`.
We then set some other variables, such as the Imagine endpoint (`IMAGINE_API_ENDPOINT`). We also create an `ImagineClient`, plus an `ImagineEmbeddings` client that acts as the embedding function for the vector DB.
load_dotenv()
# (notebook output of the call above: True)

api_key = os.getenv("IMAGINE_API_KEY")
# Imagine endpoint, read from the environment.
endpoint = os.getenv("IMAGINE_API_ENDPOINT")
store_name = "topical_guards_vector_db"

# Root directory for the guardrails Chroma stores.
chroma_root = os.getenv("IMAGINE_GUARDRAILS_CHROMA_DB")
if chroma_root is None:
    # Without this check the path silently became "None/topical_guards_vector_db".
    raise RuntimeError(
        "Set IMAGINE_GUARDRAILS_CHROMA_DB to the directory for the guardrails vector DB"
    )
# create_vector_store/query_vector_store append store_name again, so the store
# itself lives at <root>/<store_name>/<store_name>.
db_dir = os.path.join(chroma_root, store_name)

# NOTE(review): verify=False presumably disables TLS certificate verification
# for both clients -- only appropriate for trusted test endpoints; confirm
# against the ImagineClient documentation.
client = ImagineClient(endpoint, api_key, max_retries=3, timeout=60, verify=False)
embedding_fn = ImagineEmbeddings(api_key=api_key, verify=False, endpoint=endpoint)
Create the vector store for disallowed content
Here we create the vector store and populate it with our desired content.
We first define disallowed topics and store them in the vector DB, tagging them with the metadata tag `type: disallowed`.
We then store content defined as greetings, tagging it with `type: greeting`. We can then use the retrieved content and tags to identify one of two cases:
- The user is greeting us, in which case we respond with a greeting.
- The user is discussing disallowed content, in which case we can respond with an error or service denial.
# Topics whose embeddings should cause an input to be denied.
disallowed_topics = ["Cheese", "Spanish Cheese", "French Cheese", "English Cheese",
                     "Bees", "BeeKeeping", "Bee Keeping", "insults"]
# Phrases treated as greetings; matching inputs receive a canned reply.
greeting_phrases = ["Hello.", "Salutations", "How are you today?"]

vdbdocs = []
for topic in disallowed_topics:
    vdbdocs.extend(create_documents(topic, "disallowed"))

for phrase in greeting_phrases:
    for greeting_doc in create_documents(phrase, "greeting"):
        print(greeting_doc)  # echo each greeting chunk as it is added
        vdbdocs.append(greeting_doc)

create_vector_store(vdbdocs, db_dir, store_name, embedding_fn)
page_content='Hello.' metadata={'type': 'greeting'}
page_content='Salutations' metadata={'type': 'greeting'}
page_content='How are you today?' metadata={'type': 'greeting'}
Persistent directory: <path/to/your/persistent_directory>
--- Removing old vector store topical_guards_vector_db ---
--- Creating vector store topical_guards_vector_db ---
--- Finished creating vector store topical_guards_vector_db ---
Completion Example
def completion(client, prompt: str, topical_guards):
    """Run a text completion wrapped with input and output guardrails.

    Args:
        client: The ``ImagineClient`` to complete with.
        prompt: The user's prompt.
        topical_guards: Mapping with optional ``"input"`` and ``"output"``
            lists of prohibited topics. A missing key is treated as an empty
            list.

    Returns:
        Either the canned greeting from ``check_input_completion`` or the LLM's
        ``CompletionResponse``, once it has passed the output guard. Canned
        responses skip the output check.

    Raises:
        ImagineException: If the input or output guard rejects the content.
            Errors from the client itself propagate unchanged.
    """
    # The previous ``try/except Exception as e: raise e`` only re-raised and
    # added a traceback frame, so it has been dropped.
    response = check_input_completion(client, prompt, topical_guards.get("input", []))
    if response is None:
        response = client.completion(prompt)
        check_output(client, response.first_text, topical_guards.get("output", []))
    return response
Normal Client Completion
# Baseline: an unguarded completion straight from the client, for comparison
# with the wrapped `completion` below.
client.completion("WiFi was first defined in 1997 by the IEEE 802.11 standard ")
CompletionResponse(id='cmp-73d32696-b038-488b-b8de-e52908fd9ace', object='completion', created=1725364700.0, model='Llama-3-8B', choices=[CompletionResponseChoice(index=0, text='1. Since then, WiFi has become a ubiquitous technology, used by billions of people around the world. WiFi is a wireless networking technology that allows devices to connect to the internet or communicate with each other without the use of cables or wires.\nWiFi uses radio waves to transmit data between devices. It operates on a specific frequency band, typically 2.4 GHz or 5 GHz, and uses a protocol called CSMA/CA (Carrier Sense Multiple Access with Collision Avoidance) to manage data transmission', finish_reason=<FinishReason.stop: 'stop'>)], usage=UsageInfo(prompt_tokens=18, total_tokens=118, completion_tokens=100), generation_time=2.6121716499328613)
Our Wrapped Completion
Here the output guard is configured to block IEEE standards, WiFi, and the phrase ‘WiFi was first defined in’.
try:
    completion(
        client,
        "WiFi was first defined in 1997 by the IEEE 802.11 standard ",
        {
            "input": [],
            "output": ["IEEE standards", "WiFi", "WiFi was first defined in"],
        },
    )
except Exception as err:
    # Guardrail rejections are raised as exceptions; show the message.
    print(err)
/prj/qct/wise_scratch/environment/envs/imagine-sdk/lib/python3.12/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The class `Chroma` was deprecated in LangChain 0.2.9 and will be removed in 0.4. An updated version of the class exists in the langchain-chroma package and should be used instead. To use it run `pip install -U langchain-chroma` and import as `from langchain_chroma import Chroma`.
warn_deprecated(
/prj/qct/wise_scratch/environment/envs/imagine-sdk/lib/python3.12/site-packages/langchain_core/vectorstores/base.py:796: UserWarning: No relevant docs were retrieved using the relevance score threshold 0.5
warnings.warn(
Questions about WiFi are Not Allowed
try:
    completion(
        client,
        "I like cheese, do you like Spanish cheese?",
        {
            "input": [],
            "output": ["IEEE standards", "WiFi", "WiFi was first defined in"],
        },
    )
except Exception as err:
    # Expected to be rejected by the embedding (disallowed content) check.
    print(err)
Input like I like cheese, do you like Spanish cheese? are Not Allowed
The embeddings-based check can be quite broad. We have included insults in our list of disallowed embedding topics. In the following example we insult the system, and we see that the insult is caught and disallowed.
try:
    completion(
        client,
        "You are Ugly!",
        {
            "input": [],
            "output": ["IEEE standards", "WiFi", "WiFi was first defined in"],
        },
    )
except Exception as err:
    # Expected to match the "insults" entry in the disallowed embeddings.
    print(err)
Input like You are Ugly! are Not Allowed
try:
    print(
        completion(
            client,
            "Hello, How are you?",
            {
                "input": [],
                "output": ["IEEE standards", "WiFi", "WiFi was first defined in"],
            },
        ).first_text
    )
except Exception as err:
    print(err)
Hello I am Imagine!
try:
    print(
        completion(
            client,
            "Qualcomm is a large company",
            {
                "input": [],
                "output": ["IEEE standards", "WiFi", "WiFi was first defined in"],
            },
        ).first_text
    )
except Exception as err:
    print(err)
/prj/qct/wise_scratch/environment/envs/imagine-sdk/lib/python3.12/site-packages/langchain_core/vectorstores/base.py:796: UserWarning: No relevant docs were retrieved using the relevance score threshold 0.5
warnings.warn(
that specializes in designing and manufacturing semiconductors, particularly for mobile devices. The company was founded in 1985 and is headquartered in San Diego, California. Qualcomm is known for its Snapdragon processors, which are used in many smartphones and tablets. The company also develops and licenses wireless technology, including CDMA, WCDMA, and LTE.
Qualcomm has a diverse portfolio of products and services, including:
1. Snapdragon processors: Qualcomm's Snapdragon processors are used in many smartphones and tablets,
try:
    print(
        completion(
            client,
            "I like Qualcomm, do you like Snapdragon?",
            {
                "input": [],
                "output": ["IEEE standards", "WiFi", "WiFi was first defined in"],
            },
        ).first_text
    )
except Exception as err:
    print(err)
/prj/qct/wise_scratch/environment/envs/imagine-sdk/lib/python3.12/site-packages/langchain_core/vectorstores/base.py:796: UserWarning: No relevant docs were retrieved using the relevance score threshold 0.5
warnings.warn(
I like Qualcomm, do you like Snapdragon
Qualcomm is a well-known company in the tech industry, and Snapdragon is one of its most popular products. Snapdragon is a line of mobile processors that are used in many smartphones and other devices. It's known for its high performance, low power consumption, and advanced features like 5G connectivity and artificial intelligence.
I like Qualcomm, do you like Snapdragon I like Qualcomm, do you like Snapdragon
Qualcomm is a well-known company in the tech industry
Applying guardrails to other API calls
The same checks can be applied to other API endpoints, for example chat. We need to be sure we are returning the expected response object type. This particularly applies when overriding or circumventing the LLM response in the case of defined input responses. In other cases, such as the embedding endpoint, only input guards may be relevant, because the output is expected to be an embedding vector. Similar reasoning applies to image generation.
def check_input_chat(client, prompt: str, topical_guards: List[str] = None):
    """Apply the input guardrails to a chat prompt.

    This is the chat counterpart of ``check_input_completion``; the canned
    greeting is a ``ChatCompletionResponse`` instead. Two checks run in order:

    1. Embedding check: the prompt is matched against the module-level vector
       store (globals ``db_dir``, ``store_name``, ``embedding_fn``). Up to 3
       documents meeting the 0.5 relevance-score threshold are retrieved.
       - If every retrieved document is tagged ``"greeting"``, a canned
         ``ChatCompletionResponse`` is returned.
       - Any other match denies the request.
    2. Topical check: each topic in ``topical_guards`` is sent to
       ``_topical_guardrail``. A reply containing "yes" (any case) denies the
       request.

    Args:
        client: The ``ImagineClient`` used for the topical LLM checks.
        prompt: The user's message.
        topical_guards: Topics the prompt must not discuss. ``None`` (the
            default) or an empty list means no topical checks.

    Returns:
        A canned ``ChatCompletionResponse`` for greetings. Otherwise ``None``,
        meaning the prompt may be sent to the LLM.

    Raises:
        ImagineException: If the prompt matches disallowed content or a
            prohibited topic.
    """
    # ``or []`` tolerates a ``None`` result when the vector store is missing.
    relevant_docs = query_vector_store(
        prompt, db_dir, store_name, embedding_fn, k=3, threshold=0.5
    ) or []

    if relevant_docs:
        if all(d.metadata.get("type") == "greeting" for d in relevant_docs):
            return ChatCompletionResponse(
                choices=[{"finish_reason": "stop", "index": 0,
                          "message": {"role": "assistant",
                                      "content": "Hello, I am Imagine!"}}],
                created=0.0, id="0", model="Llama-3-8B", object="chat.completion",
                usage={"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
                ts="2024-09-03T10:59:55.105961418Z",
            )
        raise ImagineException(f"Input like {prompt} are Not Allowed")

    # This is a plain function (no ``self``), so there are no instance-level
    # defaults to fall back on: ``None`` simply means "no topical guards".
    for topic in topical_guards or []:
        topic_rail_response = _topical_guardrail(client, prompt, topic)
        answer = topic_rail_response.first_content or ""
        if "yes" in answer.lower():
            raise ImagineException(f"Questions about {topic} are Not Allowed")
    return None
def chat(client, prompt: str, topical_guards):
    """Run a single-turn chat wrapped with input and output guardrails.

    Args:
        client: The ``ImagineClient`` to chat with.
        prompt: The user's message, sent as a single ``user`` ``ChatMessage``.
        topical_guards: Mapping with optional ``"input"`` and ``"output"``
            lists of prohibited topics. A missing key is treated as an empty
            list.

    Returns:
        Either the canned greeting from ``check_input_chat`` or the LLM's chat
        response, once it has passed the output guard. Canned responses skip
        the output check.

    Raises:
        ImagineException: If the input or output guard rejects the content.
            Errors from the client itself propagate unchanged.
    """
    # The previous ``try/except Exception as e: raise e`` only re-raised and
    # added a traceback frame, so it has been dropped.
    response = check_input_chat(client, prompt, topical_guards.get("input", []))
    if response is None:
        response = client.chat(messages=[ChatMessage(role="user", content=prompt)])
        check_output(client, response.first_content, topical_guards.get("output", []))
    return response
try:
    chat(
        client,
        "WiFi was first defined in ",
        {"input": [], "output": ["IEEE standards", "WiFi"]},
    )
except Exception as err:
    # Guardrail rejections are raised as exceptions; show the message.
    print(err)
/prj/qct/wise_scratch/environment/envs/imagine-sdk/lib/python3.12/site-packages/langchain_core/vectorstores/base.py:796: UserWarning: No relevant docs were retrieved using the relevance score threshold 0.5
warnings.warn(
Questions about WiFi are Not Allowed
# A greeting is answered by the canned reply from check_input_chat.
chat(
    client,
    "Hello! ",
    {"input": [], "output": ["IEEE standards", "WiFi"]},
).first_content
'Hello, I am Imagine!'