module updates
This commit is contained in:
parent
a6289289ee
commit
f4b433fbc1
@ -61,6 +61,7 @@ google-auth-oauthlib
|
||||
#langchain[all]
|
||||
langchain[llms]
|
||||
langchain-community
|
||||
langchain-classic
|
||||
#langchain-cli
|
||||
#langchain-core
|
||||
langchain-openai
|
||||
|
||||
@ -14,10 +14,10 @@ logging = logger.get_logger(
|
||||
log_format='default'
|
||||
)
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from langchain.chains import GraphCypherQAChain
|
||||
from langchain_classic.chains import GraphCypherQAChain
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain_classic.prompts.prompt import PromptTemplate
|
||||
from routers.llm.private.ollama.ollama_wrapper import OllamaWrapper
|
||||
from modules.database.tools.neontology.utils import get_node_types, get_rels_by_type
|
||||
from modules.database.tools.neontology.basenode import BaseNode
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import Any, Dict
|
||||
from ollama import Client
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain.prompts.base import StringPromptValue
|
||||
|
||||
class OllamaWrapper(Runnable):
|
||||
def __init__(self, host: str, model: str = "llama3.2:latest"):
|
||||
@ -9,8 +8,11 @@ class OllamaWrapper(Runnable):
|
||||
self.model = model
|
||||
|
||||
def invoke(self, prompt: Any, config: Dict[str, Any] = None, **kwargs: Any) -> str:
|
||||
if isinstance(prompt, StringPromptValue):
|
||||
# Handle different prompt types (StringPromptValue, str, etc.)
|
||||
if hasattr(prompt, 'to_string'):
|
||||
prompt = prompt.to_string()
|
||||
elif not isinstance(prompt, str):
|
||||
prompt = str(prompt)
|
||||
|
||||
# Use the model from constructor, but allow override via kwargs
|
||||
model_name = kwargs.get("model", self.model)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user