2 changes: 2 additions & 0 deletions guardrails/api_client.py

@@ -104,6 +104,8 @@ def stream_validate(
             )
             if line:
                 json_output = json.loads(line)
+                if json_output.get("error"):
+                    raise Exception(json_output.get("error").get("message"))
                 yield IValidationOutcome.from_dict(json_output)

     def get_history(self, guard_name: str, call_id: str):
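Review note: the new check turns server-reported failures into exceptions instead of yielding malformed outcomes. A self-contained sketch of the pattern (names are hypothetical, not the real client):

import json
from typing import Iterator


def stream_outcomes(lines: Iterator[bytes]) -> Iterator[dict]:
    # Yield parsed outcomes, raising as soon as the server reports an error.
    for line in lines:
        if not line:
            continue
        payload = json.loads(line)
        if payload.get("error"):
            # Surface the server-side failure instead of yielding a bad outcome
            raise Exception(payload["error"].get("message"))
        yield payload


# The second chunk carries an error and aborts the stream.
chunks = [b'{"validated_output": "ok"}', b'{"error": {"message": "boom"}}']
try:
    for outcome in stream_outcomes(iter(chunks)):
        print(outcome)
except Exception as exc:
    print("server error:", exc)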
16 changes: 9 additions & 7 deletions guardrails/async_guard.py

@@ -574,13 +574,15 @@ async def _stream_server_call(
                 validated_output=validated_output,
                 validation_passed=(validation_output.validation_passed is True),
             )
-            if validation_output:
-                guard_history = self._api_client.get_history(
-                    self.name, validation_output.call_id
-                )
-                self.history.extend(
-                    [Call.from_interface(call) for call in guard_history]
-                )
+            # TODO: re-enable this once we have a way to get history
+            # from a multi-node server
+            # if validation_output:
+            #     guard_history = self._api_client.get_history(
+            #         self.name, validation_output.call_id
+            #     )
+            #     self.history.extend(
+            #         [Call.from_interface(call) for call in guard_history]
+            #     )
         else:
            raise ValueError("AsyncGuard does not have an api client!")
7 changes: 6 additions & 1 deletion guardrails/classes/llm/llm_response.py

@@ -47,7 +47,12 @@ def to_interface(self) -> ILLMResponse:
         stream_output = [str(so) for so in copy_2]

         async_stream_output = None
-        if self.async_stream_output:
+        # Don't do this again if the output is already async-iterable;
+        # we're updating self in memory here, so repeating the
+        # materialization can cause issues.
+        if self.async_stream_output and not hasattr(
+            self.async_stream_output, "__aiter__"
+        ):
             # tee doesn't work with async iterators
             # This may be destructive
             async_stream_output = []
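Review note: the new guard condition hinges on `__aiter__` distinguishing a still-live async iterator from output that has already been materialized (e.g., into a list). A minimal sketch of that distinction:

import asyncio


async def agen():
    yield "chunk"


async def main() -> None:
    gen = agen()
    # A live async generator exposes __aiter__; a plain list does not.
    print(hasattr(gen, "__aiter__"))        # True
    print(hasattr(["chunk"], "__aiter__"))  # False
    async for _ in gen:  # drain so the generator is cleanly finalized
        pass


asyncio.run(main())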
2 changes: 1 addition & 1 deletion guardrails/cli/create.py

@@ -118,7 +118,7 @@ def generate_template_config(
     guard_instantiations = []

     for i, guard in enumerate(template["guards"]):
-        guard_instantiations.append(f"guard{i} = Guard.from_dict(guards[{i}])")
+        guard_instantiations.append(f"guard{i} = AsyncGuard.from_dict(guards[{i}])")
     guard_instantiations = "\n".join(guard_instantiations)
     # Interpolate variables
     output_content = template_content.format(
2 changes: 1 addition & 1 deletion guardrails/cli/hub/template_config.py.template

@@ -1,6 +1,6 @@
 import json
 import os
-from guardrails import Guard
+from guardrails import AsyncGuard, Guard
 from guardrails.hub import {VALIDATOR_IMPORTS}

 try:
26 changes: 15 additions & 11 deletions guardrails/guard.py

@@ -1215,10 +1215,12 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
                 error="The response from the server was empty!",
             )

-        guard_history = self._api_client.get_history(
-            self.name, validation_output.call_id
-        )
-        self.history.extend([Call.from_interface(call) for call in guard_history])
+        # TODO: re-enable this when we have history support in
+        # multi-node server environments
+        # guard_history = self._api_client.get_history(
+        #     self.name, validation_output.call_id
+        # )
+        # self.history.extend([Call.from_interface(call) for call in guard_history])

         validation_summaries = []
         if self.history.last and self.history.last.iterations.last:
@@ -1281,13 +1283,15 @@ def _stream_server_call(
                     validated_output=validated_output,
                     validation_passed=(validation_output.validation_passed is True),
                 )
-            if validation_output:
-                guard_history = self._api_client.get_history(
-                    self.name, validation_output.call_id
-                )
-                self.history.extend(
-                    [Call.from_interface(call) for call in guard_history]
-                )
+
+            # TODO: re-enable this when the server supports multi-node history
+            # if validation_output:
+            #     guard_history = self._api_client.get_history(
+            #         self.name, validation_output.call_id
+            #     )
+            #     self.history.extend(
+            #         [Call.from_interface(call) for call in guard_history]
+            #     )
         else:
             raise ValueError("Guard does not have an api client!")
8 changes: 8 additions & 0 deletions guardrails/llm_providers.py

@@ -498,6 +498,10 @@ def _invoke_llm(
             ),
         )

+        # These kwargs are guardrails-only and should not be passed to LLMs
+        kwargs.pop("reask_prompt", None)
+        kwargs.pop("reask_instructions", None)
+
         response = completion(
             model=model,
             *args,
@@ -1088,6 +1092,10 @@ async def invoke_llm(
             ),
         )

+        # These kwargs are guardrails-only and should not be passed to LLMs
+        kwargs.pop("reask_prompt", None)
+        kwargs.pop("reask_instructions", None)
+
         response = await acompletion(
             *args,
             **kwargs,
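Review note: `reask_prompt` and `reask_instructions` are consumed by Guardrails itself; forwarding them can make the underlying LLM client reject unknown parameters. A minimal sketch of the scrubbing pattern (the function is a stand-in for the real `completion` call):

GUARDRAILS_ONLY_KWARGS = ("reask_prompt", "reask_instructions")


def call_llm(**kwargs):
    # Strip framework-internal kwargs so the LLM client never sees them.
    for key in GUARDRAILS_ONLY_KWARGS:
        kwargs.pop(key, None)
    return kwargs  # stand-in for completion(model=..., **kwargs)


print(call_llm(model="gpt-4o-mini", reask_prompt="try again"))
# -> {'model': 'gpt-4o-mini'}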
6 changes: 4 additions & 2 deletions guardrails/run/async_stream_runner.py

@@ -148,6 +148,7 @@ async def async_step(
             _ = self.is_last_chunk(chunk, api)

             fragment += chunk_text
+
             results = await validator_service.async_partial_validate(
                 chunk_text,
                 self.metadata,
@@ -157,7 +158,8 @@
                 "$",
                 True,
             )
-            validators = self.validation_map["$"] or []
+            validators = self.validation_map.get("$", [])
+
             # collect the result validated_chunk into validation progress
             # per validator
             for result in results:
@@ -210,7 +212,7 @@ async def async_step(
                 validation_progress[validator_log.validator_name] += chunk
                 # if there is an entry for every validator
                 # run a merge and emit a validation outcome
-                if len(validation_progress) == len(validators):
+                if len(validation_progress) == len(validators) or len(validators) == 0:
                     if refrain_triggered:
                         current = ""
                     else:
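Review note: the switch to `dict.get` matters because the old indexing form raises before `or []` can apply when no validator is registered for `"$"`; together with the new `len(validators) == 0` clause, unvalidated streams can still emit outcomes. A minimal illustration:

validation_map: dict = {}  # no validators registered for "$"

validators = validation_map.get("$", [])  # returns the default: []
print(validators)

try:
    validators = validation_map["$"] or []  # old form: raises first
except KeyError as exc:
    print("KeyError:", exc)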
20 changes: 15 additions & 5 deletions guardrails/telemetry/guard_tracing.py

@@ -145,9 +145,18 @@ def trace_stream_guard(
             res = next(result)  # type: ignore
             # FIXME: This should only be called once;
             # Accumulate the validated output and call at the end
-            add_guard_attributes(guard_span, history, res)
-            add_user_attributes(guard_span)
-            yield res
+            if not guard_span.is_recording():
+                # Assuming you have a tracer instance
+                tracer = get_tracer(__name__)
+                # Create a new span and link it to the previous span
+                with tracer.start_as_current_span(
+                    "stream_guard_span",  # type: ignore
+                    links=[Link(guard_span.get_span_context())],
+                ) as new_span:
+                    guard_span = new_span
+                    add_guard_attributes(guard_span, history, res)
+                    add_user_attributes(guard_span)
+            yield res
         except StopIteration:
             next_exists = False

@@ -180,6 +189,7 @@ def trace_guard_execution(
             result, ValidationOutcome
         ):
             return trace_stream_guard(guard_span, result, history)
+
         add_guard_attributes(guard_span, history, result)
         add_user_attributes(guard_span)
         return result
@@ -204,14 +214,14 @@ async def trace_async_stream_guard(
             tracer = get_tracer(__name__)
             # Create a new span and link it to the previous span
             with tracer.start_as_current_span(
-                "new_guard_span",  # type: ignore
+                "async_stream_span",  # type: ignore
                 links=[Link(guard_span.get_span_context())],
             ) as new_span:
                 guard_span = new_span

                 add_guard_attributes(guard_span, history, res)
                 add_user_attributes(guard_span)
-                yield res
+            yield res
         except StopIteration:
             next_exists = False
         except StopAsyncIteration:
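Review note: once a streamed guard's original span has ended, attributes written to it are dropped; the linked-span pattern above starts a fresh span tied back to the old one so each chunk's telemetry still lands somewhere recordable. A runnable sketch against the public OpenTelemetry API (span names illustrative; assumes opentelemetry-sdk is installed):

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Link

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("guard_span") as original:
    pass  # the span ends here; later writes to it would be dropped

print(original.is_recording())  # False

# A new span linked to the finished one keeps the trace connected.
with tracer.start_as_current_span(
    "stream_guard_span",
    links=[Link(original.get_span_context())],
) as linked:
    linked.set_attribute("chunk.index", 0)  # recorded on the live span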
5 changes: 5 additions & 0 deletions guardrails/telemetry/runner_tracing.py

@@ -265,6 +265,11 @@ def trace_call_wrapper(*args, **kwargs):
         ) as call_span:
             try:
                 response = fn(*args, **kwargs)
+                if isinstance(response, LLMResponse) and (
+                    response.async_stream_output or response.stream_output
+                ):
+                    # TODO: Iterate, add a call attr each time
+                    return response
                 add_call_attributes(call_span, response, *args, **kwargs)
                 return response
             except Exception as e:
2 changes: 2 additions & 0 deletions guardrails/validator_service/validator_service_base.py

@@ -169,6 +169,8 @@ def run_validator(

     # requires at least 2 validators
     def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
+        if len(new_values) == 0:
+            return original
         current = new_values.pop()
         while len(new_values) > 0:
             nextval = new_values.pop()
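Review note: the early return guards the `pop()` calls below it; with zero validator results the merge previously crashed instead of falling back to the original text, since:

values: list = []
try:
    values.pop()
except IndexError as exc:
    print("pop on an empty list raises:", exc)  # pop from empty list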
4 changes: 2 additions & 2 deletions server_ci/config.py

@@ -1,6 +1,6 @@
 import json
 import os
-from guardrails import Guard
+from guardrails import AsyncGuard

 try:
     file_path = os.path.join(os.getcwd(), "guard-template.json")
@@ -11,4 +11,4 @@
     SystemExit(1)

 # instantiate guards
-guard0 = Guard.from_dict(guards[0])
+guard0 = AsyncGuard.from_dict(guards[0])
55 changes: 54 additions & 1 deletion server_ci/tests/test_server.py

@@ -1,7 +1,7 @@
 import openai
 import os
 import pytest
-from guardrails import Guard, settings
+from guardrails import AsyncGuard, Guard, settings

 # OpenAI compatible Guardrails API Guard
 openai.base_url = "http://127.0.0.1:8000/guards/test-guard/openai/v1/"
@@ -32,6 +32,59 @@ def test_guard_validation(mock_llm_output, validation_output, validation_passed,
     assert validation_outcome.validated_output == validation_output


+@pytest.mark.asyncio
+async def test_async_guard_validation():
+    settings.use_server = True
+    guard = AsyncGuard(name="test-guard")
+
+    validation_outcome = await guard(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
+        temperature=0.0,
+    )
+
+    assert validation_outcome.validation_passed is True  # noqa: E712
+    assert validation_outcome.validated_output == "Citrus fruit,"
+
+
+@pytest.mark.asyncio
+async def test_async_streaming_guard_validation():
+    settings.use_server = True
+    guard = AsyncGuard(name="test-guard")
+
+    async_iterator = await guard(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
+        stream=True,
+        temperature=0.0,
+    )
+
+    full_output = ""
+    async for validation_chunk in async_iterator:
+        full_output += validation_chunk.validated_output
+
+    assert full_output == "Citrus fruit,Citrus fruit,"
+
+
+@pytest.mark.asyncio
+async def test_sync_streaming_guard_validation():
+    settings.use_server = True
+    guard = Guard(name="test-guard")
+
+    iterator = guard(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
+        stream=True,
+        temperature=0.0,
+    )
+
+    full_output = ""
+    for validation_chunk in iterator:
+        full_output += validation_chunk.validated_output
+
+    assert full_output == "Citrus fruit,Citrus fruit,"
+
+
 @pytest.mark.parametrize(
     "message_content, output, validation_passed, error",
     [