diff options
author | David Timber <dxdt@dev.snart.me> | 2024-11-19 13:15:02 +0100 |
---|---|---|
committer | David Timber <dxdt@dev.snart.me> | 2024-11-19 13:15:02 +0100 |
commit | c01c6586bc1f79510688f35824c7049172063b58 (patch) | |
tree | 3627a47ffd83b846bd5bc76f626ff11c6816a56d /shitoutcode/__init__.py |
Diffstat (limited to 'shitoutcode/__init__.py')
-rw-r--r-- | shitoutcode/__init__.py | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/shitoutcode/__init__.py b/shitoutcode/__init__.py new file mode 100644 index 0000000..67da767 --- /dev/null +++ b/shitoutcode/__init__.py @@ -0,0 +1,86 @@ +import openai +import pyjson5 + + +class Env: + def __init__(self): + self.model = "gpt-4o" + self.output_tokens_min = 250 + self.output_tokens_max = 500 + self.temp: float = None + self.seed: int = None + self.max_tokens = 8000 + +class SourceFile: + def __init__(self): + self.name: str = None + self.category: str = None + self.contents: str = None + +class LLMAPIException (Exception): ... + +def __do_prompt (lang: str, extra_prompt: str, env: Env): + prompt = ''' +Write me a program source code written in %s. Write anything you'd like. +Use more than %d words, but no more than %d words in the code.''' % ( + lang, env.output_tokens_min, env.output_tokens_max) + if extra_prompt: + prompt = ' ' + extra_prompt + + messages = [ + { + "role": "system", + "content": "You're a helpful assistant that writes " + + "computer program of any kind" + }, + { "role": "user", "content": prompt } + ] + functions = [ + { + "name": "get_result", + "description": "Output the source code", + "parameters": { + "type": "object", + "properties": { + "cat": { + "type": "string", + "description": "one word describing program functions" + }, + "filename": { + "type": "string", + "description": "The source code file name" # TODO: strip unallowed characters in Unix fs + }, + "code": { + "type": "string", + "description": "The source code" + } + } + } + } + ] + + return openai.chat.completions.create( + model = env.model, + messages = messages, + functions = functions, + function_call = "auto", + temperature = env.temp, + seed = env.seed, + max_tokens = env.max_tokens + ) + +def gen_rand_srccode (lang: str, extra_prompt: str, env: Env) -> SourceFile | str: + rsp = __do_prompt(lang, extra_prompt, env) + match rsp.choices[0].finish_reason: + case 'stop': + return rsp.choices[0].message.content + case 'function_call': + choice = pyjson5.loads(rsp.choices[0].message.function_call.arguments) + + ret = SourceFile() + ret.name = choice["filename"] + ret.category = choice.get("cat") + ret.contents = choice["code"] + + return ret + case _: raise LLMAPIException(rsp.choices[0].finish_reason) |