Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion .github/workflows/integration_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,21 @@ jobs:
pipx install poetry
poetry install

# Restore previously downloaded llm2 model files so the backend does not have to
# fetch them again on every run.
- name: Cache llm2 models
  uses: actions/cache/restore@v5
  id: cache-llm2-models-restore
  env:
    cache-name: cache-llm2-models
  with:
    # Must be the directory the backend is started with via APP_PERSISTENT_STORAGE
    # (llm2-persistent-storage/, with a hyphen); the previous value used an
    # underscore and therefore pointed at a directory that is never written.
    path: llm2-persistent-storage/
    key: ${{ runner.os }}-llm2-models-${{ env.cache-name }}-${{ hashFiles('llm2/lib/main.py') }}

- name: Install and init backend
working-directory: ${{ env.APP_NAME }}/lib
env:
APP_VERSION: ${{ fromJson(steps.appinfo.outputs.result).version }}
run: |
poetry run python3 main.py > ../backend_logs 2>&1 &
APP_PERSISTENT_STORAGE="$(pwd)/../../llm2-persistent-storage/" poetry run python3 main.py > ../backend_logs 2>&1 &

- name: Register backend
run: |
Expand Down Expand Up @@ -156,6 +165,43 @@ jobs:
curl -u "$CREDS" -H "oCS-APIRequest: true" http://localhost:8080/ocs/v2.php/taskprocessing/task/$TASK_ID?format=json
[ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ]

# Save the llm2 model files under the key computed by the restore step.
- name: Cache llm2 models
  # An exact hit means an entry for this key already exists, so saving again
  # would only produce a "cache already exists" warning.
  if: steps.cache-llm2-models-restore.outputs.cache-hit != 'true'
  uses: actions/cache/save@v5
  env:
    cache-name: cache-llm2-models
  with:
    # Must be the directory the backend is started with via APP_PERSISTENT_STORAGE
    # (llm2-persistent-storage/, with a hyphen); the previous value used an
    # underscore and therefore pointed at a directory that is never written.
    path: llm2-persistent-storage/
    key: ${{ steps.cache-llm2-models-restore.outputs.cache-primary-key }}

# Schedules a core:text2text task with "preferStreaming": true, then polls the task
# until it reports STATUS_SUCCESSFUL or STATUS_FAILED, or until the wait counter
# reaches 35. Each poll that returns a non-empty output while the task is not yet in
# a final state is counted as a streaming update. The last two commands require the
# final status to be STATUS_SUCCESSFUL and at least one streaming update to have
# been observed.
- name: Run streaming task
# Only executed for the matrix entry whose server-versions value is 'master'.
if: matrix.server-versions == 'master'
env:
# Credentials passed to curl via -u for every request in this step.
CREDS: "admin:password"
# The sleep between polls grows by one second per iteration (0, 1, 2, ...) because
# NEXT_WAIT_TIME is post-incremented inside the sleep expression.
run: |
set -x
TASK=$(curl -X POST -u "$CREDS" -H "oCS-APIRequest: true" -H "Content-type: application/json" http://localhost:8080/ocs/v2.php/taskprocessing/schedule?format=json --data-raw '{"input": {"input": "Count from 1 to 20 in words"},"type":"core:text2text", "appId": "test", "customId": "", "preferStreaming": true}')
echo $TASK
TASK_ID=$(echo $TASK | jq '.ocs.data.task.id')
NEXT_WAIT_TIME=0
TASK_STATUS='"STATUS_SCHEDULED"'
STREAMING_UPDATES=0
until [ $NEXT_WAIT_TIME -eq 35 ] || [ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ] || [ "$TASK_STATUS" == '"STATUS_FAILED"' ]; do
TASK=$(curl -u "$CREDS" -H "oCS-APIRequest: true" http://localhost:8080/ocs/v2.php/taskprocessing/task/$TASK_ID?format=json)
echo $TASK
TASK_STATUS=$(echo $TASK | jq '.ocs.data.task.status')
echo $TASK_STATUS
TASK_OUTPUT=$(echo $TASK | jq -r '.ocs.data.task.output.output // ""')
if [ -n "$TASK_OUTPUT" ] && [ "$TASK_STATUS" != '"STATUS_SUCCESSFUL"' ] && [ "$TASK_STATUS" != '"STATUS_FAILED"' ]; then
STREAMING_UPDATES=$((STREAMING_UPDATES+1))
echo "Streaming update detected (count: $STREAMING_UPDATES)"
fi
sleep $(( NEXT_WAIT_TIME++ ))
done
echo "Final status: $TASK_STATUS"
echo "Total streaming updates detected: $STREAMING_UPDATES"
[ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ]
[ $STREAMING_UPDATES -gt 0 ]

- name: Show logs
if: always()
run: |
Expand Down
2 changes: 1 addition & 1 deletion default_config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"Qwen3.5-9B-Q4_K_M": {
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n{user_prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n",
"loader_config": {
"n_ctx": 16384,
"n_ctx": 24000,
"max_tokens": 8192,
"stop": ["<|eot_id|>"],
"temperature": 0.7
Expand Down
6 changes: 4 additions & 2 deletions lib/change_tone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable

from streaming import StreamContext, run_runnable_with_streaming

class ChangeToneProcessor:

runnable: Runnable
Expand All @@ -33,10 +35,10 @@ class ChangeToneProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable

def __call__(self, input_data: dict) -> dict[str, Any]:
def __call__(self, input_data: dict, context: StreamContext | None = None) -> dict[str, Any]:
"""Process a single input"""
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format_prompt(text=input_data['input'], tone=input_data['tone']).to_string())
]
return {'output':self.runnable.invoke(messages).content }
return {'output': run_runnable_with_streaming(self.runnable, messages, context)}
22 changes: 15 additions & 7 deletions lib/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
"""A chat chain
"""
import json
from typing import Any, Optional
from typing import Any

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_community.chat_models import ChatLlamaCpp
from langchain_core.runnables import Runnable

from streaming import StreamContext, run_runnable_with_streaming


class ChatProcessor:
"""
Expand All @@ -24,10 +23,19 @@ def __init__(self, runner: Runnable):
def __call__(
self,
inputs: dict[str, Any],
context: StreamContext | None = None,
) -> dict[str, str]:
system_prompt = inputs['system_prompt']
if inputs.get('memories'):
system_prompt += "\n\nYou can remember things from other conversations with the user. If they are relevant, take into account the following memories: \n" + "\n\n".join(inputs['memories']) + "\n\n"
return {'output': self.runnable.invoke(
[('human', system_prompt)] + [(message['role'], message['content']) for message in [json.loads(message) for message in inputs['history']]] + [('human', inputs['input'])]
).content}
messages = [('human', system_prompt)] + [
(message['role'], message['content'])
for message in [json.loads(message) for message in inputs['history']]
] + [('human', inputs['input'])]
return {
'output': run_runnable_with_streaming(
self.runnable,
messages,
context,
)
}
68 changes: 58 additions & 10 deletions lib/chatwithtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,34 @@
"""A chat chain
"""
import json
import hashlib
import pprint
import re
from random import randint
from typing import Any

from langchain_community.chat_models import ChatLlamaCpp
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.messages.ai import AIMessage

from streaming import StreamContext, run_runnable_with_streaming

def generate_tool_call(tool_call: dict):
    """Render a tool call as a ``<tool_call>...</tool_call>`` text block.

    The block wraps a JSON object with two keys: ``name`` (taken from
    ``tool_call['name']``) and ``arguments`` (taken from ``tool_call['args']``).

    :param tool_call: mapping that must contain the keys ``name`` and ``args``
    :return: the tagged JSON string
    :raises KeyError: if ``name`` or ``args`` is missing
    """
    body = json.dumps({"name": tool_call['name'], "arguments": tool_call['args']})
    return f'<tool_call>{body}</tool_call>'


def generate_tool_call_id(tool_call: dict) -> str:
    """Derive a deterministic 16-character hex id from a tool call's name and arguments.

    The arguments are read from ``args`` when that key is present, otherwise from
    ``arguments``, otherwise an empty dict is used. Name and arguments are
    serialised as JSON with sorted keys, so the id does not depend on key order.

    :param tool_call: mapping describing the tool call
    :return: the first 16 hex digits of the SHA-1 digest of the serialised call
    """
    if "args" in tool_call:
        arguments = tool_call["args"]
    else:
        arguments = tool_call.get("arguments", {})

    canonical = json.dumps(
        {"name": tool_call.get("name"), "args": arguments},
        sort_keys=True,
    )
    digest = hashlib.sha1(canonical.encode("utf-8"))
    return digest.hexdigest()[:16]

def try_parse_tool_calls(content: str):
"""Try parse the tool calls."""
tool_calls = []
Expand All @@ -40,7 +53,9 @@ def try_parse_tool_calls(content: str):
func['args'] = func['arguments']
del func['arguments']
if not 'id' in func:
func['id'] = str(randint(1, 10000000000))
func['id'] = generate_tool_call_id(func)
if 'type' not in func:
func['type'] = 'tool_call'
found = True
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
Expand All @@ -66,7 +81,9 @@ def try_parse_tool_calls(content: str):
func['args'] = func['arguments']
del func['arguments']
if not 'id' in func:
func['id'] = str(randint(1, 10000000000))
func['id'] = generate_tool_call_id(func)
if 'type' not in func:
func['type'] = 'tool_call'
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
pass
Expand All @@ -79,6 +96,32 @@ def try_parse_tool_calls(content: str):
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}


def strip_tool_calls_for_streaming(content: str) -> str:
    """Remove tool-call markup from model output.

    Complete ``<tool_call>...</tool_call>`` blocks and complete
    `````tool_call`` fenced blocks are deleted. If the remaining text still
    contains the start of a tool call (``<tool`` or `````tool``), everything from
    the earliest such marker onwards is cut off. A trailing ``<|im_end|>`` token
    is removed last.

    :param content: raw text produced by the model so far
    :return: the text with tool-call markup removed
    """
    visible = content
    # Drop every finished tool-call block, tag syntax first, then fence syntax.
    for finished_block in (r"<tool_call>.*?</tool_call>", r"```tool_call\n.*?\n```"):
        visible = re.sub(finished_block, "", visible, flags=re.DOTALL)

    # Anything left that looks like the beginning of a tool call is unfinished;
    # keep only the text in front of the earliest such marker.
    cut_at = len(visible)
    for opener in ("<tool", "```tool"):
        position = visible.find(opener)
        if position != -1 and position < cut_at:
            cut_at = position
    visible = visible[:cut_at]

    return re.sub(r"<\|im_end\|>$", "", visible)


def build_streaming_payload(content: str) -> dict[str, Any] | None:
    """Turn model output into a payload dict with ``output`` and optional ``tool_calls``.

    ``output`` is the text with tool-call markup stripped. ``tool_calls`` is the
    JSON-encoded list of tool calls parsed from the content. When tool calls are
    present, ``output`` is included even if it is empty.

    :param content: raw text produced by the model so far
    :return: the payload, or ``None`` when there is neither visible text nor a
        parsed tool call
    """
    visible_text = strip_tool_calls_for_streaming(content)
    tool_calls = try_parse_tool_calls(content).get('tool_calls')

    if tool_calls:
        # Tool calls always travel together with the (possibly empty) visible text.
        return {'output': visible_text, 'tool_calls': json.dumps(tool_calls)}
    if visible_text:
        return {'output': visible_text}
    return None

class ChatWithToolsProcessor:
"""
A chat with tools processor that supports batch processing
Expand All @@ -89,7 +132,7 @@ class ChatWithToolsProcessor:
def __init__(self, runner: ChatLlamaCpp):
self.model = runner

def _process_single_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
def _process_single_input(self, input_data: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
system_prompt = """
{downstream_system_prompt}

Expand Down Expand Up @@ -150,15 +193,20 @@ def _process_single_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
messages.append(HumanMessage(content=''))

pprint.pprint(messages)
response = self.model.invoke(messages)
response_content = run_runnable_with_streaming(
self.model,
messages,
context,
stream_payload_transform=build_streaming_payload,
suppress_empty_stream_updates=True,
)

#if not response.tool_calls or len(response.tool_calls) == 0:
response = AIMessage(**try_parse_tool_calls(response.content))
response = AIMessage(**try_parse_tool_calls(response_content))

return {
'output': response.content,
'tool_calls': json.dumps(response.tool_calls)
}

def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
return self._process_single_input(inputs)
def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
return self._process_single_input(inputs, context)
8 changes: 5 additions & 3 deletions lib/contextwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable

from streaming import StreamContext, run_runnable_with_streaming

class ContextWriteProcessor:

runnable: Runnable
Expand All @@ -36,13 +38,13 @@ class ContextWriteProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable

def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format(
style_input=inputs['style_input'],
source_input=inputs['source_input']
))
]
output = self.runnable.invoke(messages)
return {'output': output.content}
output = run_runnable_with_streaming(self.runnable, messages, context)
return {'output': output}
9 changes: 6 additions & 3 deletions lib/free_prompt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later

from typing import Any, List
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import Runnable

from streaming import StreamContext, run_runnable_with_streaming


class FreePromptProcessor:
"""
Expand All @@ -20,9 +22,10 @@ def __init__(self, runnable: Runnable):
def __call__(
self,
inputs: dict[str, Any],
context: StreamContext | None = None,
) -> dict[str, Any]:
output = self.runnable.invoke([
output = run_runnable_with_streaming(self.runnable, [
SystemMessage(self.system_prompt),
HumanMessage(inputs['input'])
]).content
], context)
return {'output': output}
8 changes: 5 additions & 3 deletions lib/headline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import Runnable

from streaming import StreamContext, run_runnable_with_streaming


class HeadlineProcessor:
"""
Expand All @@ -33,12 +35,12 @@ class HeadlineProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable

def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format(
text=inputs['input']
))
]
output = self.runnable.invoke(messages)
return {'output': output.content}
output = run_runnable_with_streaming(self.runnable, messages, context)
return {'output': output}
Loading
Loading